Open In Colab

Getting Started¶

Overview¶

This semester, all homeworks will be conducted through Google Colab notebooks. All code for the homework assignment will be written and run in this notebook. Running in Colab will automatically provide a GPU, but you may also run this notebook locally by following these instructions if you wish to use your own GPU.

You will save images in the notebooks to use and fill out a given LaTeX template which will be submitted to Gradescope, along with your notebook code.

Using Colab¶

On the left-hand side, you can click the different icons to see a Table of Contents of the assignment, as well as local files accessible through the notebook.

Make sure to go to Runtime -> Change runtime type and select GPU as the hardware accelerator. This allows you to use a GPU. Run the cells below to get started on the assignment. Note that a session is open for a maximum of 12 hours, and using too much GPU compute may result in restricted access for a short period of time. Please start the homework early so you have ample time to work.

If you load this notebook by clicking "Open in Colab" from github, you will need to save it to your own Google Drive to keep your work.

General Tips¶

In each homework problem, you will implement autoregressive models and run it on various datasets. Oftentime you will run it on two datasets (dataset 1 and dataset 2). In these cases, the expected outputs for dataset 1 are already provided to help as a sanity check.

Feel free to print whatever output (e.g. debugging code, training code, etc) you want, as the graded submission will be the submitted pdf with images.

After you complete the assignment, download all of the images outputted in the results/ folder and upload them to the figure folder in the given latex template.

There is a lot of freedom in this homework to design write and design your own models. Hyperparameters are given as a guide to show what worked for us, but feel free to explore and use what you find is best!

Run the cells below to download and load up the starter code.

In [1]:
# !if [ -d deepul ]; then rm -Rf deepul; fi
# !git clone https://github.com/rll/deepul.git 
# !unzip -qq deepul/homeworks/hw1/data/hw1_data.zip -d deepul/homeworks/hw1/data/
# !pip install ./deepul
In [1]:
import numpy as np
import copy
# import jax.numpy as np

from deepul.hw1_helper import (
    # Q1
    visualize_q1_data,
    q1_sample_data_1,
    q1_sample_data_2,
    q1_save_results,
    # Q2
    q2a_save_results,
    q2b_save_results,
    visualize_q2a_data,
    visualize_q2b_data,
    # Q3
    
    q3ab_save_results,
    q3c_save_results,
    # Q4
    q4a_save_results,
    q4b_save_results,
    # Q5
    visualize_q5_data,
    q5a_save_results,
    # Q6
    visualize_q6_data,
    q6a_save_results,
)

Question 1: 1D Data¶

In this question, we will train simple generative models on discrete 1D data.

Execute the cell below to visualize our datasets

In [5]:
visualize_q1_data(dset_type=1)
visualize_q1_data(dset_type=2)
Dataset 1
No description has been provided for this image
Dataset 2
No description has been provided for this image

Part (a) Fitting a Histogram¶

Let $\theta = (\theta_0, \dots, \theta_{d-1}) \in \mathbb{R}^{d}$ and define the model $p_\theta(x) = \frac{e^{\theta_x}}{\sum_{x'}e^{\theta_{x'}}}$

Fit $p_\theta$ with maximum likelihood via stochastic gradient descent on the training set, using $\theta$ initialized to zero. Use your favorite version of stochastic gradient descent, and optimize your hyperparameters on a validation set of your choice.

You will provide these deliverables

  1. Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and test data (for your entire test set). Code is provided that automatically plots the training curves.
  2. Report the final test set performance of your final model
  3. Plot the model probabilities in a bar graph with $\{0,\dots,d-1\}$ on the x-axis and a real number in $[0,1]$ on the y-axis.

Fill out the function below and return the necessary arguments. Feel free to create more cells if need be.

In [6]:
def softmax_probs(theta):
    """Computes all softmax probabilities for given theta."""
    theta_shifted = theta - np.max(theta)  # Stability trick
    exp_theta = np.exp(theta_shifted)
    return exp_theta / np.sum(exp_theta)
def q1_a(train_data, test_data, d, dset_id):
    """
    Optimized version of q1_a for better performance
    """
    train_losses = []
    test_losses = []
    theta = np.zeros(d, dtype=float)
    y = np.bincount(train_data, minlength=d) / len(train_data)
    
    print("train_data shape: ", train_data.shape)
    print("d: ", d)
    print("dset_id", dset_id)

    # Hyperparameters
    epochs = 200000
    learning_rate = 0.1
    
    # Calculate test loss only periodically
    eval_interval = 10
    
    # Pre-allocate arrays
    train_losses = np.zeros(epochs)
    test_losses = np.zeros(epochs)
    
    for epoch in range(epochs):
        probs = softmax_probs(theta)
        nll = -np.mean(np.log(probs[train_data]))
        d_nll = probs - y
        
        theta -= learning_rate * d_nll
        train_losses[epoch] = nll
        
        # Calculate test loss only periodically
        if epoch % eval_interval == 0:
            t_nll = -np.mean(np.log(probs[test_data]))
            test_losses[epoch] = t_nll
        elif epoch > 0:  # Copy the previous value for non-evaluated epochs
            test_losses[epoch] = test_losses[epoch-1]
    
    distribution = softmax_probs(theta)
    return train_losses, test_losses, distribution

Results¶

Once you've implemented q1_a, execute the cells below to visualize and save your results

In [7]:
q1_save_results(1, 'a', q1_a)
train_data shape:  (800,)
d:  20
dset_id 1
Final Test Loss: 2.5434
No description has been provided for this image
No description has been provided for this image
In [8]:
q1_save_results(2, 'a', q1_a)
train_data shape:  (8000,)
d:  100
dset_id 2
Final Test Loss: 3.6897
No description has been provided for this image
No description has been provided for this image

Part (b) Fitting Discretized Mixture of Logistics¶

Let us model $p_\theta(x)$ as a discretized mixture of 4 logistics such that $p_\theta(x) = \sum_{i=1}^4 \pi_i[\sigma((x+0.5 - \mu_i)/s_i) - \sigma((x-0.5-\mu_i)/s_i)]$

For the edge case of when $x = 0$, we replace $x-0.5$ by $-\infty$, and for $x = 99$, we replace $x+0.5$ by $\infty$.

You may find the PixelCNN++ helpful for more information on discretized mixture of logistics.

Provide the same set of corresponding deliverables as part (a)

Fill out the function below and return the necessary arguments. Feel free to create more cells if need be.

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

class MixtureOfLogistics(nn.Module):
    """Mixture of Logistics distribution model for discrete data."""
    
    def __init__(self, d, n_mix=4):
        """
        Initialize the Mixture of Logistics model.
        
        Args:
            d: Number of possible discrete values (0 to d-1)
            n_mix: Number of mixture components
        """
        super().__init__()
        self.d = d
        self.n_mix = n_mix
        
        # TODO: Initialize model parameters
        # 1. Mixture weights (logit_probs)
        # 2. Component means (means)
        # 3. Component scales (log_scales)
        self.logit_probs = nn.Parameter(torch.zeros(n_mix)) # shape (n_mix,)
        init_means = torch.linspace(0,d-1,n_mix) + torch.randn(n_mix)*0.1  #shape (n_mix,)
        self.means = nn.Parameter(torch.clamp(init_means,0,d-1))
        self.log_scales = nn.Parameter(torch.ones(n_mix)*0.5) # shape (n_mix,1)
        
    
    def forward(self, x):
        """
        Compute the log probability of each value in x.
        
        Args:
            x: tensor of shape (batch_size,) containing integers in {0, ..., d-1}
            
        Returns:
            tensor of shape (batch_size,) containing log probabilities
        """
        # TODO: Implement forward pass
        # 1. Get mixture weights using softmax on logit_probs
        # 2. Ensure scales are positive (e.g., using softplus)
        # 3. Calculate CDF at x+0.5 and x-0.5 for each component
        # 4. Handle edge cases (x=0 and x=d-1)
        # 5. Compute probabilities from CDF differences
        # 6. Weight by mixture weights and sum
        # 7. Return log probabilities
        probs = torch.softmax(self.logit_probs, dim=0)  # n_mix
        probs = probs.unsqueeze(0)  # 1, n_mix

        scales = nn.functional.softplus(self.log_scales)  # n_mix 
        scales = scales.unsqueeze(0)  # 1, n_mix

        x = x.unsqueeze(1)  # batch_size,1
        x_float = x.float()
        first_term = torch.sigmoid((x_float + 0.5 - self.means) / scales)
        second_term = torch.sigmoid((x_float - 0.5 - self.means) / scales)

        # Handle edge cases
        is_d_minus_one = (x == self.d - 1)
        is_zero = (x == 0)
        first_term = torch.where(is_d_minus_one, torch.ones_like(first_term), first_term)
        second_term = torch.where(is_zero, torch.zeros_like(second_term), second_term)

        # Calculate component probabilities
        component_probs = first_term - second_term  # batch_size, n_mix
        probabilities = probs * component_probs  # batch_size, n_mix
        probabilities = probabilities.sum(dim=1)  # batch_size

        # Return log probabilities
        return torch.log(probabilities + 1e-10)
    
    def get_distribution(self):
        """
        Returns the probability distribution over all possible values.
        
        Returns:
            numpy array of shape (d,) containing probabilities
        """
        device = next(self.parameters()).device
        x = torch.arange(self.d, device=device)
        with torch.no_grad():
            log_probs = self.forward(x)
            probs = torch.exp(log_probs)
            return (probs / torch.sum(probs)).cpu().numpy()


def q1_b(train_data, test_data, d, dset_id):
    """
    Train a mixture of logistics model on discrete data.
    """
    # Set random seed and device
    torch.manual_seed(42)
    np.random.seed(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Convert data to tensors
    train_tensor = torch.tensor(train_data, dtype=torch.long, device=device)
    test_tensor = torch.tensor(test_data, dtype=torch.long, device=device)
    
    # Create dataset and dataloader
    train_dataset = TensorDataset(train_tensor)
    test_dataset = TensorDataset(test_tensor)
    # Set hyperparameters based on dataset ID
    hyperparams = {
        1: {"batch_size": 800, "lr": 0.005, "num_epochs": 10000}, 
        2: {"batch_size": 800, "lr": 0.001, "num_epochs": 10000}
    }[dset_id]
    
    train_loader = DataLoader(
        train_dataset, 
        batch_size=hyperparams["batch_size"], 
        shuffle=True
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=hyperparams["batch_size"],
        shuffle=False
    )
    train_losses = np.zeros(hyperparams["num_epochs"])
    test_losses = np.zeros(hyperparams["num_epochs"])
    # Initialize model and optimizer
    model = MixtureOfLogistics(d, n_mix=4).to(device)
    optimizer = optim.Adam(model.parameters(), lr=hyperparams["lr"])
    
    # TODO: Implement training loop
    # 1. Track training and test losses
    # 2. For each epoch:
    #    a. Train the model on batches
    #    b. Compute and record loss
    #    c. Evaluate on test set
    for epoch in range(hyperparams["num_epochs"]):
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            optimizer.zero_grad()
            x = batch[0].to(device)
            log_probs = model(x)
            nll = -log_probs.mean()
            nll.backward()
            
            optimizer.step()
            train_loss += nll.item()
        train_losses[epoch] = train_loss / len(train_loader)
        
        model.eval()
        test_loss = 0.0
        with torch.no_grad():
            for batch in test_loader:
                x = batch[0].to(device)
                log_probs = model(x)
                nll = -log_probs.mean()
                test_loss += nll.item()
        test_losses[epoch] = test_loss / len(test_loader)
    # Get final model probabilities
    model.eval()
    model_probs = model.get_distribution()
    print("len(model_probs)", len(model_probs))
    # Return placeholder values for now
    return train_losses, test_losses, model_probs

Results¶

Once you've implemented q1_b, execute the cells below to visualize and save your results

In [10]:
q1_save_results(1, 'b', q1_b)
len(model_probs) 20
Final Test Loss: 2.5499
No description has been provided for this image
No description has been provided for this image
In [11]:
q1_save_results(2, 'b', q1_b)
len(model_probs) 100
Final Test Loss: 4.0082
No description has been provided for this image
No description has been provided for this image

Question 2 PixelCNNs¶

Now, you will train more powerful PixelCNN models on the shapes dataset and MNIST. In addition, we will extend to modeling colored datasets.

Run the cell below to visualize the two datasets binary datasets

In [5]:
visualize_q2a_data(1)
visualize_q2a_data(2)
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
samples shape:  (100, 20, 20, 1)
No description has been provided for this image
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
samples shape:  (100, 28, 28, 1)
No description has been provided for this image

Part (a) PixelCNN on Shapes and MNIST¶

In this part, implement a simple PixelCNN architecture to model binary MNIST and shapes images (same as Q2(b), but with a PixelCNN).

We recommend the following network design:

  • A $7 \times 7$ masked type A convolution
  • $5$ $7 \times 7$ masked type B convolutions
  • $2$ $1 \times 1$ masked type B convolutions
  • Appropriate ReLU nonlinearities in-between
  • 64 convolutional filters

And the following hyperparameters:

  • Batch size 128
  • Learning rate $10^{-3}$
  • 10 epochs
  • Adam Optimizer (this applies to all PixelCNN models trained in future parts)

Your model should output logits, after which you could apply a sigmoid over 1 logit, or a softmax over two logits (either is fine). It may also help to scale your input to $[-1, 1]$ before running it through the network.

Training on the shapes dataset should be quick, and MNIST should take around 10 minutes

Checkout the Paper for more details: https://arxiv.org/abs/1601.06759

You will provide these deliverables

  1. Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and test data (for your entire test set). Code is provided that automatically plots the training curves.
  2. Report the final test set performance of your final model
  3. 100 samples from the final trained model

Fill out the function below and return the necessary arguments. Feel free to create more cells if need be.

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from tqdm import tqdm

class MaskedConv2d(nn.Conv2d):
    """
    Implementation of a masked convolution layer for PixelCNN.
    Masks can be of type 'A' or 'B'.
    """
    def __init__(self, in_channels, out_channels, kernel_size, mask_type='A', padding='same', **kwargs):
        super(MaskedConv2d, self).__init__(in_channels, out_channels, kernel_size, padding=padding, **kwargs)
        self.register_buffer('mask', torch.ones_like(self.weight))
        self.mask_type = mask_type
        
        # Create mask
        h, w = kernel_size, kernel_size
        if isinstance(kernel_size, int):
            h, w = kernel_size, kernel_size
            
        center_h, center_w = h // 2, w // 2
        
        # For all spatial locations
        for i in range(h):
            for j in range(w):
                # Mask out future pixels (below and to the right)
                if (i > center_h) or (i == center_h and j > center_w):
                    self.mask[:, :, i, j] = 0
                    
                # For mask type A, also mask out the center pixel
                if mask_type == 'A' and i == center_h and j == center_w:
                    self.mask[:, :, i, j] = 0
    
    def forward(self, x):
        # Apply the mask to weights
        self.weight.data *= self.mask
        return super(MaskedConv2d, self).forward(x)

class PixelCNN(nn.Module):
    """
    PixelCNN model for binary image generation.
    As recommended in the assignment:
    - One 7x7 masked type A convolution
    - Five 7x7 masked type B convolutions
    - Two 1x1 masked type B convolutions
    - ReLU nonlinearities in-between
    - 64 convolutional filters
    """
    def __init__(self, in_channels=1, hidden_dim=64):
        super(PixelCNN, self).__init__()
        
        # Initial masked convolutional layer of type A
        self.conv_a = MaskedConv2d(in_channels, hidden_dim, kernel_size=7, mask_type='A', padding='same')
        
        # Stack of masked convolutional layers of type B
        self.conv_b_stack = nn.ModuleList([
            MaskedConv2d(hidden_dim, hidden_dim, kernel_size=7, mask_type='B', padding='same')
            for _ in range(5)
        ])
        
        # Final 1x1 convolutions
        self.conv_1x1_stack = nn.ModuleList([
            MaskedConv2d(hidden_dim, hidden_dim, kernel_size=1, mask_type='B', padding='same')
            for _ in range(2)
        ])
        
        # Output layer: 1 channel for binary output (will apply sigmoid later)
        self.output_conv = MaskedConv2d(hidden_dim, 1, kernel_size=1, mask_type='B', padding='same')
        
    def forward(self, x):
        # Apply first mask A convolution
        x = F.relu(self.conv_a(x))
        
        # Apply mask B convolutions with ReLU
        for conv_b in self.conv_b_stack:
            x = F.relu(conv_b(x))
            
        # Apply 1x1 convolutions with ReLU
        for conv_1x1 in self.conv_1x1_stack:
            x = F.relu(conv_1x1(x))
            
        # Final output layer (returns logits)
        x = self.output_conv(x)
        
        return x

def sample_from_model(model, image_shape, device, num_samples=100):
    """
    Sample images from the trained model using ancestral sampling.
    """
    model.eval()
    samples = torch.zeros((num_samples, 1, image_shape[0], image_shape[1]), device=device)
    
    with torch.no_grad():
        # Generate each pixel sequentially
        for i in range(image_shape[0]):
            for j in range(image_shape[1]):
                # Get the model's prediction
                logits = model(samples)[:, :, i, j]
                # Convert logits to probabilities
                probs = torch.sigmoid(logits)
                # Sample from Bernoulli distribution
                samples[:, :, i, j] = torch.bernoulli(probs)
    
    return samples.cpu().numpy().transpose(0, 2, 3, 1)

def binary_cross_entropy_loss(logits, targets):
    """
    Compute binary cross entropy loss from logits.
    """
    return F.binary_cross_entropy_with_logits(logits, targets)

def negative_log_likelihood(logits, targets):
    """
    Compute negative log likelihood in nats per dimension.
    """
    batch_size = targets.size(0)
    n_dims = targets.size(1) * targets.size(2) * targets.size(3)
    
    # Compute binary cross entropy (already in log scale)
    bce = binary_cross_entropy_loss(logits, targets)
    
    # Convert to nats (from bits) and normalize by dimensions
    # No need for conversion as PyTorch already uses natural log
    nll = bce * n_dims
    
    return nll
In [7]:
def q2_a(train_data, test_data, image_shape, dset_id):
    """
    train_data: A (n_train, H, W, 1) uint8 numpy array of binary images with values in {0, 1}
    test_data: A (n_test, H, W, 1) uint8 numpy array of binary images with values in {0, 1}
    image_shape: (H, W), height and width of the image
    dset_id: An identifying number of which dataset is given (1 or 2). Most likely
             used to set different hyperparameters for different datasets

    Returns
    - a (# of training iterations,) numpy array of train_losses evaluated every minibatch
    - a (# of epochs + 1,) numpy array of test_losses evaluated once at initialization and after each epoch
    - a numpy array of size (100, H, W, 1) of samples with values in {0, 1}
    """
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Hyperparameters
    batch_size = 128
    learning_rate = 1e-3
    num_epochs = 10
    
    # Convert data to PyTorch tensors and scale to [-1, 1]
    train_data = torch.from_numpy(train_data).float().permute(0, 3, 1, 2).to(device)
    test_data = torch.from_numpy(test_data).float().permute(0, 3, 1, 2).to(device)
    
    # Create data loaders
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)
    
    # Initialize model, optimizer
    model = PixelCNN(in_channels=1, hidden_dim=64).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Lists to store metrics
    train_losses = []
    test_losses = []
    
    # Initial test loss
    model.eval()
    total_test_loss = 0
    with torch.no_grad():
        for data in test_loader:
            outputs = model(data)
            loss = negative_log_likelihood(outputs, data)
            total_test_loss += loss.item()
    initial_test_loss = total_test_loss / len(test_loader)
    test_losses.append(initial_test_loss)
    
    # Training loop
    for epoch in range(num_epochs):
        model.train()
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        
        for batch_idx, data in enumerate(progress_bar):
            optimizer.zero_grad()
            outputs = model(data)
            loss = binary_cross_entropy_loss(outputs, data)
            loss.backward()
            optimizer.step()
            
            # Track the average NLL per dimension
            nll = negative_log_likelihood(outputs, data)
            train_losses.append(nll.item())
            
            progress_bar.set_postfix({'Loss': nll.item()})
        
        # Evaluate on test set after each epoch
        model.eval()
        total_test_loss = 0
        with torch.no_grad():
            for data in test_loader:
                outputs = model(data)
                loss = negative_log_likelihood(outputs, data)
                total_test_loss += loss.item()
        epoch_test_loss = total_test_loss / len(test_loader)
        test_losses.append(epoch_test_loss)
        
        print(f'Epoch {epoch+1}: Test Loss: {epoch_test_loss:.6f}')
    
    # Generate samples
    samples = sample_from_model(model, image_shape, device, num_samples=100)
    samples = (samples > 0.5).astype(np.uint8)  # Convert probabilities to binary
    
    return np.array(train_losses), np.array(test_losses), samples

Results¶

Once you've implemented q2_a, execute the cells below to visualize and save your results

In [8]:
q2a_save_results(1, q2_a)
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
Epoch 1/10: 100%|██████████| 82/82 [00:01<00:00, 69.27it/s, Loss=62.2]
Epoch 1: Test Loss: 63.498842
Epoch 2/10: 100%|██████████| 82/82 [00:01<00:00, 76.21it/s, Loss=53.7]
Epoch 2: Test Loss: 56.255177
Epoch 3/10: 100%|██████████| 82/82 [00:01<00:00, 76.00it/s, Loss=43.2]
Epoch 3: Test Loss: 44.143267
Epoch 4/10: 100%|██████████| 82/82 [00:01<00:00, 75.82it/s, Loss=35.9]
Epoch 4: Test Loss: 35.370099
Epoch 5/10: 100%|██████████| 82/82 [00:01<00:00, 75.84it/s, Loss=31.9]
Epoch 5: Test Loss: 30.587291
Epoch 6/10: 100%|██████████| 82/82 [00:01<00:00, 75.87it/s, Loss=26.7]
Epoch 6: Test Loss: 27.376665
Epoch 7/10: 100%|██████████| 82/82 [00:01<00:00, 75.84it/s, Loss=25.1]
Epoch 7: Test Loss: 26.798818
Epoch 8/10: 100%|██████████| 82/82 [00:01<00:00, 76.25it/s, Loss=24.3]
Epoch 8: Test Loss: 23.934251
Epoch 9/10: 100%|██████████| 82/82 [00:01<00:00, 76.02it/s, Loss=21.8]
Epoch 9: Test Loss: 22.455027
Epoch 10/10: 100%|██████████| 82/82 [00:01<00:00, 76.16it/s, Loss=20.7]
Epoch 10: Test Loss: 20.706674
Final Test Loss: 20.7067
No description has been provided for this image
samples shape:  (100, 20, 20, 1)
No description has been provided for this image
In [9]:
q2a_save_results(2, q2_a)
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
Epoch 1/10: 100%|██████████| 469/469 [00:10<00:00, 43.04it/s, Loss=68.7]
Epoch 1: Test Loss: 69.982057
Epoch 2/10: 100%|██████████| 469/469 [00:10<00:00, 42.91it/s, Loss=68.3]
Epoch 2: Test Loss: 66.214365
Epoch 3/10: 100%|██████████| 469/469 [00:11<00:00, 42.53it/s, Loss=64.3]
Epoch 3: Test Loss: 65.107751
Epoch 4/10: 100%|██████████| 469/469 [00:11<00:00, 42.28it/s, Loss=62]  
Epoch 4: Test Loss: 63.785228
Epoch 5/10: 100%|██████████| 469/469 [00:11<00:00, 42.00it/s, Loss=65.2]
Epoch 5: Test Loss: 63.833508
Epoch 6/10: 100%|██████████| 469/469 [00:11<00:00, 42.03it/s, Loss=62.2]
Epoch 6: Test Loss: 62.760284
Epoch 7/10: 100%|██████████| 469/469 [00:11<00:00, 42.03it/s, Loss=64.5]
Epoch 7: Test Loss: 62.533142
Epoch 8/10: 100%|██████████| 469/469 [00:11<00:00, 42.61it/s, Loss=64.3]
Epoch 8: Test Loss: 62.928753
Epoch 9/10: 100%|██████████| 469/469 [00:10<00:00, 42.89it/s, Loss=61.7]
Epoch 9: Test Loss: 61.973298
Epoch 10/10: 100%|██████████| 469/469 [00:11<00:00, 42.39it/s, Loss=61]  
Epoch 10: Test Loss: 61.620466
Final Test Loss: 61.6205
No description has been provided for this image
samples shape:  (100, 28, 28, 1)
No description has been provided for this image

Part (b) PixelCNN on Colored Shapes and MNIST: Independent Color Channels¶

For the next part, we'll work with color images (shapes and MNIST). Run the cell below to visualize the dataset.

In [10]:
visualize_q2b_data(1)
visualize_q2b_data(2)
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
samples shape:  (100, 20, 20, 3)
No description has been provided for this image
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
samples shape:  (100, 28, 28, 3)
No description has been provided for this image

Now, implement a PixelCNN to support RGB color channels (or augment your existing implementation). First, implement a PixelCNN that assumes color channels as independent. More formally, we model the following parameterized distribution:

$$p_\theta(x) = \prod_{i=1}^{HW}\prod_{c=1}^C p_\theta(x_i^c | x_{<i})$$

Here are some tips that you may find useful for designing and training these models:

  • You will need a 4-way softmax for every prediction, as opposed to a 256-way softmax in the PixelCNN paper, since the dataset is quantized to two bits per color channel
  • You can set the number of filters for each convolutions to 120. You can use the ReLU nonlinearity throughout.
  • Use a stack of 8 residual block architecture from Figure 5 but with 7 x 7 masked convolutions in the middle instead of 3 x 3 masked convolutions
  • Consider using layer normalization to improve performance. However, be careful to maintain the autoregressive property.
  • With a learning rate of $10^{-3}$ and a batch size of 128, it should take a few minutes to run on the shapes dataset, and about 50-60 minutes on MNIST.

You will provide these deliverables

  1. Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and test data (for your entire test set). Code is provided that automatically plots the training curves.
  2. Report the final test set performance of your final model
  3. 100 samples from the final trained model

Fill out the function below and return the necessary arguments. Feel free to create more cells if need be.

In [11]:
import torch
import torch.optim as optim
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

quiet = False

def train(model, train_loader, optimizer, epoch, grad_clip=None):
  model.train()
  
  train_losses = []
  for x in train_loader:
    x = x.cuda().contiguous()
    loss = model.loss(x)
    optimizer.zero_grad()
    loss.backward()
    if grad_clip:
      torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()
    train_losses.append(loss.item())
  return train_losses

def eval_loss(model, data_loader):
  model.eval()
  total_loss = 0
  with torch.no_grad():
    for x in data_loader:
      x = x.cuda().contiguous()
      loss = model.loss(x)
      total_loss += loss * x.shape[0]
    avg_loss = total_loss / len(data_loader.dataset)

  return avg_loss.item()


def train_epochs(model, train_loader, test_loader, train_args):
  epochs, lr = train_args['epochs'], train_args['lr']
  grad_clip = train_args.get('grad_clip', None)
  optimizer = optim.Adam(model.parameters(), lr=lr)

  train_losses = []
  test_losses = [eval_loss(model, test_loader)]
  for epoch in range(epochs):
    model.train()
    train_losses.extend(train(model, train_loader, optimizer, epoch, grad_clip))
    test_loss = eval_loss(model, test_loader)
    test_losses.append(test_loss)
    if not quiet:
      print(f'Epoch {epoch}, Test loss {test_loss:.4f}')

  return train_losses, test_losses

class Histogram(nn.Module):
  def __init__(self, d):
    super().__init__()
    self.d = d
    self.logits = nn.Parameter(torch.zeros(d), requires_grad=True)

  def loss(self, x):
    logits = self.logits.unsqueeze(0).repeat(x.shape[0], 1) # batch_size x d
    return F.cross_entropy(logits, x.long())

  def get_distribution(self):
    distribution = F.softmax(self.logits, dim=0)
    return distribution.detach().cpu().numpy()
In [12]:
class MaskedConv2d(nn.Conv2d):
    """
    Implementation of Masked Convolution for PixelCNN
    """
    def __init__(self, mask_type, *args, **kwargs):
        assert mask_type == 'A' or mask_type == 'B'
        super().__init__(*args, **kwargs)
        self.register_buffer('mask', torch.zeros_like(self.weight))
        self.create_mask(mask_type)

    def forward(self, input):
        # Apply convolution with mask
        out = F.conv2d(input, self.weight * self.mask, self.bias, self.stride,
                      self.padding, self.dilation, self.groups)
        return out

    def create_mask(self, mask_type):
        # Get kernel size (assuming square kernel)
        k = self.kernel_size[0]
        
        # Set mask to 1 for all positions above the center
        self.mask[:, :, :k // 2] = 1
        
        # Set mask to 1 for positions to the left of center in the center row
        self.mask[:, :, k // 2, :k // 2] = 1
        
        # For type B masks, also set the center pixel to 1
        if mask_type == 'B':
            self.mask[:, :, k // 2, k // 2] = 1
    
class LayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    
    def forward(self, x):
        x = x.permute(0, 2, 3, 1).contiguous()
        
        # Apply normalization (normalized_shape now matches)
        x = super().forward(x)
        
        # Permute back to [batch, channels, height, width]
        return x.permute(0, 3, 1, 2).contiguous()
  
class ResidualBlock(nn.Module):
    def __init__(self, n_filters, image_shape=None):
        super(ResidualBlock, self).__init__()
        # layer nornalization 
        H, W, _ = image_shape if image_shape is not None else (20, 20, 3) 
        self.layer_norm1 = LayerNorm(n_filters , H, W)
        self.layer_norm2 = LayerNorm(n_filters , H, W)
        self.layer_norm3 = LayerNorm(n_filters , H, W)

        # Main path
        self.conv1 = nn.Conv2d(n_filters, n_filters, kernel_size=1)
        self.conv2 = MaskedConv2d('B', n_filters, n_filters, kernel_size=7, padding=3)
        self.conv3 = nn.Conv2d(n_filters, n_filters, kernel_size=1)

        self.relu = nn.ReLU()
        
    def forward(self, x):
        # Store input for the skip connection
        
        identity = x
        
        # Main path
        out = self.layer_norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        
        out = self.layer_norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        
        out = self.layer_norm3(out)
        out = self.relu(out)   
        out = self.conv3(out)

        # Skip connection
        out += identity        
        return out
In [13]:
def compute_loss(logits, targets):
    """
    Compute cross-entropy loss for the PixelCNN model.
    
    Args:
        logits: Tensor of shape [batch_size, C, 4, H, W] - model predictions
        targets: Tensor of shape [batch_size, C, H, W] - ground truth values
        
    Returns:
        Average cross-entropy loss
    """
    logits_reshaped = logits.permute(0, 1, 3, 4, 2).reshape(-1, 4) # [batch_size, C, H, W 4]
    targets_reshaped = targets.reshape(-1).long()
    return F.cross_entropy(logits_reshaped, targets_reshaped)


def evaluate_model(model, data_loader, device):
    """
    Evaluate the model on a dataset
    """
    model.eval()
    total_loss = 0.0
    total_batches = 0
    
    with torch.no_grad():
        for (data,) in data_loader:
            data = data.to(device)
            total_batches += 1
            
            logits = model(data)
            loss = compute_loss(logits, data)
            total_loss += loss.item()
    
    return total_loss / total_batches



def generate_samples(model, num_samples, image_shape, device):
    """
    Generate samples from the model using ancestral sampling.
    Assumes the model expects float inputs representing integer values {0, 1, 2, 3}.
    """
    model.eval()
    H, W, C = image_shape
    
    # Initialize samples tensor with float type, as the model likely expects float inputs
    # even though the values represent discrete levels.
    samples = torch.zeros(num_samples, C, H, W, dtype=torch.float, device=device)
    temperature = 0.6
    with torch.no_grad():
        for h in range(H):
            for w in range(W):
                for c in range(C):
                    logits = model(samples) 
                    
                    pixel_logits = logits[:, c, :, h, w] # Shape: [num_samples, 4]
                    
                    probs = F.softmax(pixel_logits / temperature, dim=1)
                    
                    pixel_samples = torch.multinomial(probs, 1).squeeze(-1) # Shape: [num_samples]
                    
                    samples[:, c, h, w] = pixel_samples.float() 
    
    samples_np = samples.cpu().numpy().transpose(0, 2, 3, 1) 
    
    samples_np = samples_np.astype(np.uint8) 
    
    return samples_np
In [14]:
class PixelCNN(nn.Module):
    """
    PixelCNN model with masked convolutions
    """
    def __init__(self, image_shape, dset_id):
        super(PixelCNN, self).__init__()
        self.image_shape = image_shape
        self.dset_id = dset_id
        self.n_colors = 4
        # Number of input channels (1 for MNIST and shapes)
        in_channels = 3
        
        # Number of filters as specified in the assignment
        n_filters = 120
        
        # First layer: 7x7 masked type A convolution
        self.conv_A = MaskedConv2d('A', in_channels, n_filters, 7, padding=3, bias=True)

        # 8 layers of 7x7 masked type B convolutions
        self.residual_layers = nn.ModuleList([
            ResidualBlock(n_filters, image_shape)
            for _ in range(8)
        ])
        
        # 2 layers of 1x1 masked type B convolutions
        self.conv_B_1x1_layers = nn.ModuleList([
            MaskedConv2d('B', n_filters, n_filters, 1, padding=0, bias=True)
            for _ in range(2)
        ])
        
        # Output 4 logits for each of the 3 color channels
        self.output_conv = nn.Conv2d(n_filters, 4*3, 1, padding=0, bias=True)
        
        # ReLU activation
        self.relu = nn.ReLU()

    def forward(self, x):
        # Apply first mask A convolution
        x = (x.float() / (self.n_colors - 1) - 0.5) / 0.5
        x = self.relu(self.conv_A(x))
        
        # Apply mask B convolutions with ReLU activations
        for layer in self.residual_layers:
            x  = layer(x)
        
        # Apply 1x1 mask B convolutions with ReLU activations
        for layer in self.conv_B_1x1_layers:
            x = self.relu(layer(x))

        
        # Apply final convolution to get logits
        x = self.output_conv(x)


        batch_size, _, height, width = x.shape
        x = x.view(batch_size, 3,4,  height, width) 
        
        return x
    
In [17]:
def q2_b(train_data, test_data, image_shape, dset_id):
    """
    Trains a PixelCNN model for RGB images with 4 possible values per channel.
    
    Args:
        train_data: A (n_train, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
        test_data: A (n_test, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
        image_shape: (H, W, C), height, width, and # of channels of the image
        dset_id: An identifying number of which dataset is given (1 or 2)
                 Used to set different hyperparameters for different datasets

    Returns:
        - train_losses: A (# of training iterations,) numpy array of per-batch training losses
        - test_losses: A (# of epochs + 1,) numpy array of test losses after each epoch (including initialization)
        - samples: A (100, H, W, C) numpy array of generated samples with values in {0, 1, 2, 3}
    """ 
    # Hyperparameters
    batch_size = 128 
    learning_rate = 0.001 * np.sqrt(batch_size / 128)
    num_epochs = 20
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Data preparation
    train_data_tensor = torch.FloatTensor(train_data).permute(0, 3, 1, 2)
    test_data_tensor = torch.FloatTensor(test_data).permute(0, 3, 1, 2)
    
    train_dataset = torch.utils.data.TensorDataset(train_data_tensor)
    test_dataset = torch.utils.data.TensorDataset(test_data_tensor)
    
    train_loader = torch.utils.data.DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True, 
        num_workers=2, 
        pin_memory=(device.type == 'cuda')
    )
    
    test_loader = torch.utils.data.DataLoader(
        test_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        num_workers=2, 
        pin_memory=(device.type == 'cuda')
    )
    
    # Model initialization
    model = PixelCNN(image_shape, dset_id).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',        # Reduce LR when the monitored metric stops decreasing
        factor=0.5,        # Multiply learning rate by this factor when reducing
        patience=2,        # Number of epochs with no improvement after which LR will be reduced
        # verbose=True,      # Print message when LR is reduced
        min_lr=1e-6        # Lower bound on the learning rate
    )
    
    # Initialize gradient scaler for mixed precision training
    scaler = torch.cuda.amp.GradScaler()

    # Early stopping parameters
    best_loss = float('inf')
    best_model_state = None
    patience = 5
    patience_counter = 0

    # Loss tracking
    train_losses = []
    test_losses = []
    
    # Initial evaluation
    init_test_loss = evaluate_model(model, test_loader, device)
    test_losses.append(init_test_loss)
    print(f"Initial test loss: {init_test_loss:.6f}")
    
    # Training loop
    for epoch in range(num_epochs):
        model.train()
        epoch_train_losses = []
        
        for batch_idx, (data,) in enumerate(train_loader):
            data = data.to(device)
            optimizer.zero_grad()
            
            # Use autocast for mixed precision training
            with torch.cuda.amp.autocast():
                logits = model(data)
                loss = compute_loss(logits, data)

            # Scale the loss and backpropagate
            scaler.scale(loss).backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # Reduced from 1.0 for stability
            
            # Update weights with scaled gradients
            scaler.step(optimizer)
            scaler.update()
            
            loss_value = loss.item()
            train_losses.append(loss_value)
            epoch_train_losses.append(loss_value)
            
            if batch_idx % 10 == 0:
                print(f'Epoch: {epoch+1}/{num_epochs}, Batch: {batch_idx}/{len(train_loader)}, '
                      f'Loss: {loss_value:.6f}')
        
        avg_train_loss = sum(epoch_train_losses) / len(epoch_train_losses)
        print(f'Epoch {epoch+1} average training loss: {avg_train_loss:.6f}')
        
        # Evaluate the model
        test_loss = evaluate_model(model, test_loader, device)
        test_losses.append(test_loss)
        print(f'Epoch {epoch+1} test loss: {test_loss:.6f}')
        
        # Update learning rate based on test loss
        scheduler.step(test_loss)
        
        # Print current learning rate (correct way for ReduceLROnPlateau)
        print(f'Current learning rate: {optimizer.param_groups[0]["lr"]:.6f}')
        
        # Early stopping check
        if test_loss < best_loss:
            best_loss = test_loss
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
            print(f"New best model with test loss: {best_loss:.6f}")
        else:
            patience_counter += 1
            print(f"No improvement for {patience_counter} epochs")
            
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break
    
    # Load the best model for sampling
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model with test loss: {best_loss:.6f}")
    
    # Generate samples with improved sampling
    samples = generate_samples(model, 100, image_shape, device)
    
    return np.array(train_losses), np.array(test_losses), samples


def evaluate_model(model, data_loader, device):
    """
    Evaluate the model on a dataset
    """
    model.eval()
    total_loss = 0.0
    total_batches = 0
    
    with torch.no_grad():
        for (data,) in data_loader:
            data = data.to(device)
            total_batches += 1
            
            with torch.cuda.amp.autocast():  # Use autocast for evaluation too
                logits = model(data)
                loss = compute_loss(logits, data)
                
            total_loss += loss.item()
    
    return total_loss / total_batches

Results¶

Once you've implemented q2_b, execute the cells below to visualize and save your results

In [18]:
q2b_save_results(1, 'b', q2_b)
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
Using device: cuda
/tmp/ipykernel_2361820/2857294626.py:61: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  scaler = torch.cuda.amp.GradScaler()
/tmp/ipykernel_2361820/2857294626.py:162: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with torch.cuda.amp.autocast():  # Use autocast for evaluation too
Initial test loss: 1.388850
/tmp/ipykernel_2361820/2857294626.py:88: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with torch.cuda.amp.autocast():
Epoch: 1/20, Batch: 0/82, Loss: 1.391449
Epoch: 1/20, Batch: 10/82, Loss: 1.329665
Epoch: 1/20, Batch: 20/82, Loss: 1.191522
Epoch: 1/20, Batch: 30/82, Loss: 0.999146
Epoch: 1/20, Batch: 40/82, Loss: 0.806963
Epoch: 1/20, Batch: 50/82, Loss: 0.667744
Epoch: 1/20, Batch: 60/82, Loss: 0.571827
Epoch: 1/20, Batch: 70/82, Loss: 0.516736
Epoch: 1/20, Batch: 80/82, Loss: 0.472714
Epoch 1 average training loss: 0.874104
Epoch 1 test loss: 0.472679
Current learning rate: 0.001000
New best model with test loss: 0.472679
Epoch: 2/20, Batch: 0/82, Loss: 0.465338
Epoch: 2/20, Batch: 10/82, Loss: 0.428478
Epoch: 2/20, Batch: 20/82, Loss: 0.404972
Epoch: 2/20, Batch: 30/82, Loss: 0.361007
Epoch: 2/20, Batch: 40/82, Loss: 0.347785
Epoch: 2/20, Batch: 50/82, Loss: 0.336060
Epoch: 2/20, Batch: 60/82, Loss: 0.321864
Epoch: 2/20, Batch: 70/82, Loss: 0.314387
Epoch: 2/20, Batch: 80/82, Loss: 0.283376
Epoch 2 average training loss: 0.361904
Epoch 2 test loss: 0.286814
Current learning rate: 0.001000
New best model with test loss: 0.286814
Epoch: 3/20, Batch: 0/82, Loss: 0.276777
Epoch: 3/20, Batch: 10/82, Loss: 0.266449
Epoch: 3/20, Batch: 20/82, Loss: 0.255323
Epoch: 3/20, Batch: 30/82, Loss: 0.240444
Epoch: 3/20, Batch: 40/82, Loss: 0.226757
Epoch: 3/20, Batch: 50/82, Loss: 0.205264
Epoch: 3/20, Batch: 60/82, Loss: 0.209912
Epoch: 3/20, Batch: 70/82, Loss: 0.184275
Epoch: 3/20, Batch: 80/82, Loss: 0.179151
Epoch 3 average training loss: 0.226909
Epoch 3 test loss: 0.184995
Current learning rate: 0.001000
New best model with test loss: 0.184995
Epoch: 4/20, Batch: 0/82, Loss: 0.179256
Epoch: 4/20, Batch: 10/82, Loss: 0.172584
Epoch: 4/20, Batch: 20/82, Loss: 0.173224
Epoch: 4/20, Batch: 30/82, Loss: 0.155870
Epoch: 4/20, Batch: 40/82, Loss: 0.151243
Epoch: 4/20, Batch: 50/82, Loss: 0.149168
Epoch: 4/20, Batch: 60/82, Loss: 0.149802
Epoch: 4/20, Batch: 70/82, Loss: 0.127425
Epoch: 4/20, Batch: 80/82, Loss: 0.131133
Epoch 4 average training loss: 0.154246
Epoch 4 test loss: 0.131514
Current learning rate: 0.001000
New best model with test loss: 0.131514
Epoch: 5/20, Batch: 0/82, Loss: 0.127266
Epoch: 5/20, Batch: 10/82, Loss: 0.127639
Epoch: 5/20, Batch: 20/82, Loss: 0.117048
Epoch: 5/20, Batch: 30/82, Loss: 0.126710
Epoch: 5/20, Batch: 40/82, Loss: 0.118495
Epoch: 5/20, Batch: 50/82, Loss: 0.118372
Epoch: 5/20, Batch: 60/82, Loss: 0.114484
Epoch: 5/20, Batch: 70/82, Loss: 0.110210
Epoch: 5/20, Batch: 80/82, Loss: 0.111262
Epoch 5 average training loss: 0.119808
Epoch 5 test loss: 0.112343
Current learning rate: 0.001000
New best model with test loss: 0.112343
Epoch: 6/20, Batch: 0/82, Loss: 0.107689
Epoch: 6/20, Batch: 10/82, Loss: 0.106192
Epoch: 6/20, Batch: 20/82, Loss: 0.107759
Epoch: 6/20, Batch: 30/82, Loss: 0.111071
Epoch: 6/20, Batch: 40/82, Loss: 0.108741
Epoch: 6/20, Batch: 50/82, Loss: 0.108976
Epoch: 6/20, Batch: 60/82, Loss: 0.110752
Epoch: 6/20, Batch: 70/82, Loss: 0.104242
Epoch: 6/20, Batch: 80/82, Loss: 0.105161
Epoch 6 average training loss: 0.107490
Epoch 6 test loss: 0.103943
Current learning rate: 0.001000
New best model with test loss: 0.103943
Epoch: 7/20, Batch: 0/82, Loss: 0.104092
Epoch: 7/20, Batch: 10/82, Loss: 0.107277
Epoch: 7/20, Batch: 20/82, Loss: 0.103453
Epoch: 7/20, Batch: 30/82, Loss: 0.101431
Epoch: 7/20, Batch: 40/82, Loss: 0.101527
Epoch: 7/20, Batch: 50/82, Loss: 0.096972
Epoch: 7/20, Batch: 60/82, Loss: 0.098578
Epoch: 7/20, Batch: 70/82, Loss: 0.099638
Epoch: 7/20, Batch: 80/82, Loss: 0.099156
Epoch 7 average training loss: 0.100878
Epoch 7 test loss: 0.098755
Current learning rate: 0.001000
New best model with test loss: 0.098755
Epoch: 8/20, Batch: 0/82, Loss: 0.097980
Epoch: 8/20, Batch: 10/82, Loss: 0.099518
Epoch: 8/20, Batch: 20/82, Loss: 0.096265
Epoch: 8/20, Batch: 30/82, Loss: 0.101945
Epoch: 8/20, Batch: 40/82, Loss: 0.096464
Epoch: 8/20, Batch: 50/82, Loss: 0.097598
Epoch: 8/20, Batch: 60/82, Loss: 0.096989
Epoch: 8/20, Batch: 70/82, Loss: 0.097894
Epoch: 8/20, Batch: 80/82, Loss: 0.093489
Epoch 8 average training loss: 0.097322
Epoch 8 test loss: 0.096321
Current learning rate: 0.001000
New best model with test loss: 0.096321
Epoch: 9/20, Batch: 0/82, Loss: 0.097640
Epoch: 9/20, Batch: 10/82, Loss: 0.096712
Epoch: 9/20, Batch: 20/82, Loss: 0.099051
Epoch: 9/20, Batch: 30/82, Loss: 0.094936
Epoch: 9/20, Batch: 40/82, Loss: 0.091105
Epoch: 9/20, Batch: 50/82, Loss: 0.088823
Epoch: 9/20, Batch: 60/82, Loss: 0.090040
Epoch: 9/20, Batch: 70/82, Loss: 0.093468
Epoch: 9/20, Batch: 80/82, Loss: 0.096179
Epoch 9 average training loss: 0.094821
Epoch 9 test loss: 0.094621
Current learning rate: 0.001000
New best model with test loss: 0.094621
Epoch: 10/20, Batch: 0/82, Loss: 0.094129
Epoch: 10/20, Batch: 10/82, Loss: 0.094922
Epoch: 10/20, Batch: 20/82, Loss: 0.089785
Epoch: 10/20, Batch: 30/82, Loss: 0.094372
Epoch: 10/20, Batch: 40/82, Loss: 0.089206
Epoch: 10/20, Batch: 50/82, Loss: 0.094630
Epoch: 10/20, Batch: 60/82, Loss: 0.094062
Epoch: 10/20, Batch: 70/82, Loss: 0.089265
Epoch: 10/20, Batch: 80/82, Loss: 0.090081
Epoch 10 average training loss: 0.093107
Epoch 10 test loss: 0.091463
Current learning rate: 0.001000
New best model with test loss: 0.091463
Epoch: 11/20, Batch: 0/82, Loss: 0.090736
Epoch: 11/20, Batch: 10/82, Loss: 0.088854
Epoch: 11/20, Batch: 20/82, Loss: 0.092054
Epoch: 11/20, Batch: 30/82, Loss: 0.092395
Epoch: 11/20, Batch: 40/82, Loss: 0.090031
Epoch: 11/20, Batch: 50/82, Loss: 0.090507
Epoch: 11/20, Batch: 60/82, Loss: 0.094138
Epoch: 11/20, Batch: 70/82, Loss: 0.091632
Epoch: 11/20, Batch: 80/82, Loss: 0.090938
Epoch 11 average training loss: 0.090473
Epoch 11 test loss: 0.089798
Current learning rate: 0.001000
New best model with test loss: 0.089798
Epoch: 12/20, Batch: 0/82, Loss: 0.088649
Epoch: 12/20, Batch: 10/82, Loss: 0.087368
Epoch: 12/20, Batch: 20/82, Loss: 0.083911
Epoch: 12/20, Batch: 30/82, Loss: 0.090002
Epoch: 12/20, Batch: 40/82, Loss: 0.094381
Epoch: 12/20, Batch: 50/82, Loss: 0.085193
Epoch: 12/20, Batch: 60/82, Loss: 0.085929
Epoch: 12/20, Batch: 70/82, Loss: 0.085538
Epoch: 12/20, Batch: 80/82, Loss: 0.087217
Epoch 12 average training loss: 0.088904
Epoch 12 test loss: 0.087783
Current learning rate: 0.001000
New best model with test loss: 0.087783
Epoch: 13/20, Batch: 0/82, Loss: 0.083210
Epoch: 13/20, Batch: 10/82, Loss: 0.090468
Epoch: 13/20, Batch: 20/82, Loss: 0.089873
Epoch: 13/20, Batch: 30/82, Loss: 0.085177
Epoch: 13/20, Batch: 40/82, Loss: 0.085015
Epoch: 13/20, Batch: 50/82, Loss: 0.084958
Epoch: 13/20, Batch: 60/82, Loss: 0.087870
Epoch: 13/20, Batch: 70/82, Loss: 0.085928
Epoch: 13/20, Batch: 80/82, Loss: 0.081580
Epoch 13 average training loss: 0.087305
Epoch 13 test loss: 0.087037
Current learning rate: 0.001000
New best model with test loss: 0.087037
Epoch: 14/20, Batch: 0/82, Loss: 0.084078
Epoch: 14/20, Batch: 10/82, Loss: 0.090743
Epoch: 14/20, Batch: 20/82, Loss: 0.088010
Epoch: 14/20, Batch: 30/82, Loss: 0.084594
Epoch: 14/20, Batch: 40/82, Loss: 0.088533
Epoch: 14/20, Batch: 50/82, Loss: 0.081967
Epoch: 14/20, Batch: 60/82, Loss: 0.087185
Epoch: 14/20, Batch: 70/82, Loss: 0.088961
Epoch: 14/20, Batch: 80/82, Loss: 0.082240
Epoch 14 average training loss: 0.086166
Epoch 14 test loss: 0.086303
Current learning rate: 0.001000
New best model with test loss: 0.086303
Epoch: 15/20, Batch: 0/82, Loss: 0.083607
Epoch: 15/20, Batch: 10/82, Loss: 0.084251
Epoch: 15/20, Batch: 20/82, Loss: 0.089648
Epoch: 15/20, Batch: 30/82, Loss: 0.082188
Epoch: 15/20, Batch: 40/82, Loss: 0.087374
Epoch: 15/20, Batch: 50/82, Loss: 0.087053
Epoch: 15/20, Batch: 60/82, Loss: 0.091182
Epoch: 15/20, Batch: 70/82, Loss: 0.082506
Epoch: 15/20, Batch: 80/82, Loss: 0.092663
Epoch 15 average training loss: 0.086581
Epoch 15 test loss: 0.088489
Current learning rate: 0.001000
No improvement for 1 epochs
Epoch: 16/20, Batch: 0/82, Loss: 0.083316
Epoch: 16/20, Batch: 10/82, Loss: 0.083976
Epoch: 16/20, Batch: 20/82, Loss: 0.081897
Epoch: 16/20, Batch: 30/82, Loss: 0.080853
Epoch: 16/20, Batch: 40/82, Loss: 0.082935
Epoch: 16/20, Batch: 50/82, Loss: 0.080993
Epoch: 16/20, Batch: 60/82, Loss: 0.087351
Epoch: 16/20, Batch: 70/82, Loss: 0.080957
Epoch: 16/20, Batch: 80/82, Loss: 0.089425
Epoch 16 average training loss: 0.085030
Epoch 16 test loss: 0.086015
Current learning rate: 0.001000
New best model with test loss: 0.086015
Epoch: 17/20, Batch: 0/82, Loss: 0.082371
Epoch: 17/20, Batch: 10/82, Loss: 0.087590
Epoch: 17/20, Batch: 20/82, Loss: 0.083791
Epoch: 17/20, Batch: 30/82, Loss: 0.080405
Epoch: 17/20, Batch: 40/82, Loss: 0.085472
Epoch: 17/20, Batch: 50/82, Loss: 0.087293
Epoch: 17/20, Batch: 60/82, Loss: 0.085169
Epoch: 17/20, Batch: 70/82, Loss: 0.085010
Epoch: 17/20, Batch: 80/82, Loss: 0.092598
Epoch 17 average training loss: 0.085812
Epoch 17 test loss: 0.086086
Current learning rate: 0.001000
No improvement for 1 epochs
Epoch: 18/20, Batch: 0/82, Loss: 0.085155
Epoch: 18/20, Batch: 10/82, Loss: 0.087516
Epoch: 18/20, Batch: 20/82, Loss: 0.084148
Epoch: 18/20, Batch: 30/82, Loss: 0.085064
Epoch: 18/20, Batch: 40/82, Loss: 0.092279
Epoch: 18/20, Batch: 50/82, Loss: 0.081455
Epoch: 18/20, Batch: 60/82, Loss: 0.087196
Epoch: 18/20, Batch: 70/82, Loss: 0.083074
Epoch: 18/20, Batch: 80/82, Loss: 0.083386
Epoch 18 average training loss: 0.085906
Epoch 18 test loss: 0.090479
Current learning rate: 0.001000
No improvement for 2 epochs
Epoch: 19/20, Batch: 0/82, Loss: 0.091709
Epoch: 19/20, Batch: 10/82, Loss: 0.083963
Epoch: 19/20, Batch: 20/82, Loss: 0.083727
Epoch: 19/20, Batch: 30/82, Loss: 0.085782
Epoch: 19/20, Batch: 40/82, Loss: 0.083908
Epoch: 19/20, Batch: 50/82, Loss: 0.086676
Epoch: 19/20, Batch: 60/82, Loss: 0.086588
Epoch: 19/20, Batch: 70/82, Loss: 0.082210
Epoch: 19/20, Batch: 80/82, Loss: 0.087303
Epoch 19 average training loss: 0.085502
Epoch 19 test loss: 0.085137
Current learning rate: 0.001000
New best model with test loss: 0.085137
Epoch: 20/20, Batch: 0/82, Loss: 0.086982
Epoch: 20/20, Batch: 10/82, Loss: 0.088542
Epoch: 20/20, Batch: 20/82, Loss: 0.090561
Epoch: 20/20, Batch: 30/82, Loss: 0.078148
Epoch: 20/20, Batch: 40/82, Loss: 0.080311
Epoch: 20/20, Batch: 50/82, Loss: 0.090522
Epoch: 20/20, Batch: 60/82, Loss: 0.081093
Epoch: 20/20, Batch: 70/82, Loss: 0.078895
Epoch: 20/20, Batch: 80/82, Loss: 0.088900
Epoch 20 average training loss: 0.085175
Epoch 20 test loss: 0.084870
Current learning rate: 0.001000
New best model with test loss: 0.084870
Loaded best model with test loss: 0.084870
Final Test Loss: 0.0849
No description has been provided for this image
samples shape:  (100, 20, 20, 3)
No description has been provided for this image
In [19]:
q2b_save_results(2, 'b', q2_b)
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
Using device: cuda
/tmp/ipykernel_2361820/2857294626.py:61: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  scaler = torch.cuda.amp.GradScaler()
/tmp/ipykernel_2361820/2857294626.py:162: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with torch.cuda.amp.autocast():  # Use autocast for evaluation too
Initial test loss: 1.365160
/tmp/ipykernel_2361820/2857294626.py:88: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with torch.cuda.amp.autocast():
Epoch: 1/20, Batch: 0/469, Loss: 1.364236
Epoch: 1/20, Batch: 10/469, Loss: 1.236289
Epoch: 1/20, Batch: 20/469, Loss: 0.936831
Epoch: 1/20, Batch: 30/469, Loss: 0.858414
Epoch: 1/20, Batch: 40/469, Loss: 0.776309
Epoch: 1/20, Batch: 50/469, Loss: 0.680986
Epoch: 1/20, Batch: 60/469, Loss: 0.598954
Epoch: 1/20, Batch: 70/469, Loss: 0.516625
Epoch: 1/20, Batch: 80/469, Loss: 0.491731
Epoch: 1/20, Batch: 90/469, Loss: 0.435464
Epoch: 1/20, Batch: 100/469, Loss: 0.415777
Epoch: 1/20, Batch: 110/469, Loss: 0.375786
Epoch: 1/20, Batch: 120/469, Loss: 0.360483
Epoch: 1/20, Batch: 130/469, Loss: 0.322623
Epoch: 1/20, Batch: 140/469, Loss: 0.319569
Epoch: 1/20, Batch: 150/469, Loss: 0.290041
Epoch: 1/20, Batch: 160/469, Loss: 0.281337
Epoch: 1/20, Batch: 170/469, Loss: 0.275442
Epoch: 1/20, Batch: 180/469, Loss: 0.276323
Epoch: 1/20, Batch: 190/469, Loss: 0.266604
Epoch: 1/20, Batch: 200/469, Loss: 0.266382
Epoch: 1/20, Batch: 210/469, Loss: 0.264654
Epoch: 1/20, Batch: 220/469, Loss: 0.260019
Epoch: 1/20, Batch: 230/469, Loss: 0.266150
Epoch: 1/20, Batch: 240/469, Loss: 0.273089
Epoch: 1/20, Batch: 250/469, Loss: 0.256204
Epoch: 1/20, Batch: 260/469, Loss: 0.263772
Epoch: 1/20, Batch: 270/469, Loss: 0.261555
Epoch: 1/20, Batch: 280/469, Loss: 0.258613
Epoch: 1/20, Batch: 290/469, Loss: 0.263624
Epoch: 1/20, Batch: 300/469, Loss: 0.269133
Epoch: 1/20, Batch: 310/469, Loss: 0.263299
Epoch: 1/20, Batch: 320/469, Loss: 0.261412
Epoch: 1/20, Batch: 330/469, Loss: 0.253682
Epoch: 1/20, Batch: 340/469, Loss: 0.261186
Epoch: 1/20, Batch: 350/469, Loss: 0.261738
Epoch: 1/20, Batch: 360/469, Loss: 0.253270
Epoch: 1/20, Batch: 370/469, Loss: 0.254910
Epoch: 1/20, Batch: 380/469, Loss: 0.251309
Epoch: 1/20, Batch: 390/469, Loss: 0.250324
Epoch: 1/20, Batch: 400/469, Loss: 0.249594
Epoch: 1/20, Batch: 410/469, Loss: 0.262300
Epoch: 1/20, Batch: 420/469, Loss: 0.249880
Epoch: 1/20, Batch: 430/469, Loss: 0.239845
Epoch: 1/20, Batch: 440/469, Loss: 0.252041
Epoch: 1/20, Batch: 450/469, Loss: 0.255249
Epoch: 1/20, Batch: 460/469, Loss: 0.263435
Epoch 1 average training loss: 0.374674
Epoch 1 test loss: 0.251673
Current learning rate: 0.001000
New best model with test loss: 0.251673
Epoch: 2/20, Batch: 0/469, Loss: 0.255822
Epoch: 2/20, Batch: 10/469, Loss: 0.259323
Epoch: 2/20, Batch: 20/469, Loss: 0.267333
Epoch: 2/20, Batch: 30/469, Loss: 0.253701
Epoch: 2/20, Batch: 40/469, Loss: 0.259780
Epoch: 2/20, Batch: 50/469, Loss: 0.250533
Epoch: 2/20, Batch: 60/469, Loss: 0.252443
Epoch: 2/20, Batch: 70/469, Loss: 0.266876
Epoch: 2/20, Batch: 80/469, Loss: 0.235499
Epoch: 2/20, Batch: 90/469, Loss: 0.248745
Epoch: 2/20, Batch: 100/469, Loss: 0.244379
Epoch: 2/20, Batch: 110/469, Loss: 0.249996
Epoch: 2/20, Batch: 120/469, Loss: 0.246507
Epoch: 2/20, Batch: 130/469, Loss: 0.249283
Epoch: 2/20, Batch: 140/469, Loss: 0.253623
Epoch: 2/20, Batch: 150/469, Loss: 0.249731
Epoch: 2/20, Batch: 160/469, Loss: 0.252728
Epoch: 2/20, Batch: 170/469, Loss: 0.245919
Epoch: 2/20, Batch: 180/469, Loss: 0.246375
Epoch: 2/20, Batch: 190/469, Loss: 0.253191
Epoch: 2/20, Batch: 200/469, Loss: 0.246197
Epoch: 2/20, Batch: 210/469, Loss: 0.248491
Epoch: 2/20, Batch: 220/469, Loss: 0.246431
Epoch: 2/20, Batch: 230/469, Loss: 0.248911
Epoch: 2/20, Batch: 240/469, Loss: 0.245220
Epoch: 2/20, Batch: 250/469, Loss: 0.246489
Epoch: 2/20, Batch: 260/469, Loss: 0.249481
Epoch: 2/20, Batch: 270/469, Loss: 0.246066
Epoch: 2/20, Batch: 280/469, Loss: 0.246358
Epoch: 2/20, Batch: 290/469, Loss: 0.247223
Epoch: 2/20, Batch: 300/469, Loss: 0.249195
Epoch: 2/20, Batch: 310/469, Loss: 0.247354
Epoch: 2/20, Batch: 320/469, Loss: 0.254715
Epoch: 2/20, Batch: 330/469, Loss: 0.242449
Epoch: 2/20, Batch: 340/469, Loss: 0.244332
Epoch: 2/20, Batch: 350/469, Loss: 0.248880
Epoch: 2/20, Batch: 360/469, Loss: 0.242313
Epoch: 2/20, Batch: 370/469, Loss: 0.233009
Epoch: 2/20, Batch: 380/469, Loss: 0.240140
Epoch: 2/20, Batch: 390/469, Loss: 0.231701
Epoch: 2/20, Batch: 400/469, Loss: 0.239487
Epoch: 2/20, Batch: 410/469, Loss: 0.244440
Epoch: 2/20, Batch: 420/469, Loss: 0.241557
Epoch: 2/20, Batch: 430/469, Loss: 0.247809
Epoch: 2/20, Batch: 440/469, Loss: 0.244429
Epoch: 2/20, Batch: 450/469, Loss: 0.236603
Epoch: 2/20, Batch: 460/469, Loss: 0.243593
Epoch 2 average training loss: 0.247342
Epoch 2 test loss: 0.241083
Current learning rate: 0.001000
New best model with test loss: 0.241083
Epoch: 3/20, Batch: 0/469, Loss: 0.239596
Epoch: 3/20, Batch: 10/469, Loss: 0.248777
Epoch: 3/20, Batch: 20/469, Loss: 0.245636
Epoch: 3/20, Batch: 30/469, Loss: 0.244686
Epoch: 3/20, Batch: 40/469, Loss: 0.239712
Epoch: 3/20, Batch: 50/469, Loss: 0.239452
Epoch: 3/20, Batch: 60/469, Loss: 0.244311
Epoch: 3/20, Batch: 70/469, Loss: 0.246592
Epoch: 3/20, Batch: 80/469, Loss: 0.230436
Epoch: 3/20, Batch: 90/469, Loss: 0.232847
Epoch: 3/20, Batch: 100/469, Loss: 0.238146
Epoch: 3/20, Batch: 110/469, Loss: 0.247689
Epoch: 3/20, Batch: 120/469, Loss: 0.239976
Epoch: 3/20, Batch: 130/469, Loss: 0.240102
Epoch: 3/20, Batch: 140/469, Loss: 0.236300
Epoch: 3/20, Batch: 150/469, Loss: 0.237734
Epoch: 3/20, Batch: 160/469, Loss: 0.231506
Epoch: 3/20, Batch: 170/469, Loss: 0.233169
Epoch: 3/20, Batch: 180/469, Loss: 0.241323
Epoch: 3/20, Batch: 190/469, Loss: 0.239814
Epoch: 3/20, Batch: 200/469, Loss: 0.240700
Epoch: 3/20, Batch: 210/469, Loss: 0.236182
Epoch: 3/20, Batch: 220/469, Loss: 0.237793
Epoch: 3/20, Batch: 230/469, Loss: 0.234612
Epoch: 3/20, Batch: 240/469, Loss: 0.234325
Epoch: 3/20, Batch: 250/469, Loss: 0.230192
Epoch: 3/20, Batch: 260/469, Loss: 0.240093
Epoch: 3/20, Batch: 270/469, Loss: 0.234376
Epoch: 3/20, Batch: 280/469, Loss: 0.233690
Epoch: 3/20, Batch: 290/469, Loss: 0.233709
Epoch: 3/20, Batch: 300/469, Loss: 0.239333
Epoch: 3/20, Batch: 310/469, Loss: 0.232038
Epoch: 3/20, Batch: 320/469, Loss: 0.230965
Epoch: 3/20, Batch: 330/469, Loss: 0.234805
Epoch: 3/20, Batch: 340/469, Loss: 0.229885
Epoch: 3/20, Batch: 350/469, Loss: 0.233624
Epoch: 3/20, Batch: 360/469, Loss: 0.236268
Epoch: 3/20, Batch: 370/469, Loss: 0.234259
Epoch: 3/20, Batch: 380/469, Loss: 0.229853
Epoch: 3/20, Batch: 390/469, Loss: 0.235964
Epoch: 3/20, Batch: 400/469, Loss: 0.227787
Epoch: 3/20, Batch: 410/469, Loss: 0.240188
Epoch: 3/20, Batch: 420/469, Loss: 0.226364
Epoch: 3/20, Batch: 430/469, Loss: 0.234578
Epoch: 3/20, Batch: 440/469, Loss: 0.233978
Epoch: 3/20, Batch: 450/469, Loss: 0.233747
Epoch: 3/20, Batch: 460/469, Loss: 0.238040
Epoch 3 average training loss: 0.236651
Epoch 3 test loss: 0.230761
Current learning rate: 0.001000
New best model with test loss: 0.230761
Epoch: 4/20, Batch: 0/469, Loss: 0.234140
Epoch: 4/20, Batch: 10/469, Loss: 0.232589
Epoch: 4/20, Batch: 20/469, Loss: 0.231800
Epoch: 4/20, Batch: 30/469, Loss: 0.235732
Epoch: 4/20, Batch: 40/469, Loss: 0.229330
Epoch: 4/20, Batch: 50/469, Loss: 0.231006
Epoch: 4/20, Batch: 60/469, Loss: 0.224475
Epoch: 4/20, Batch: 70/469, Loss: 0.226936
Epoch: 4/20, Batch: 80/469, Loss: 0.233463
Epoch: 4/20, Batch: 90/469, Loss: 0.235182
Epoch: 4/20, Batch: 100/469, Loss: 0.233420
Epoch: 4/20, Batch: 110/469, Loss: 0.233781
Epoch: 4/20, Batch: 120/469, Loss: 0.228672
Epoch: 4/20, Batch: 130/469, Loss: 0.235871
Epoch: 4/20, Batch: 140/469, Loss: 0.223315
Epoch: 4/20, Batch: 150/469, Loss: 0.232540
Epoch: 4/20, Batch: 160/469, Loss: 0.222886
Epoch: 4/20, Batch: 170/469, Loss: 0.229659
Epoch: 4/20, Batch: 180/469, Loss: 0.227260
Epoch: 4/20, Batch: 190/469, Loss: 0.229968
Epoch: 4/20, Batch: 200/469, Loss: 0.221816
Epoch: 4/20, Batch: 210/469, Loss: 0.230529
Epoch: 4/20, Batch: 220/469, Loss: 0.221697
Epoch: 4/20, Batch: 230/469, Loss: 0.234514
Epoch: 4/20, Batch: 240/469, Loss: 0.226947
Epoch: 4/20, Batch: 250/469, Loss: 0.230282
Epoch: 4/20, Batch: 260/469, Loss: 0.236895
Epoch: 4/20, Batch: 270/469, Loss: 0.228116
Epoch: 4/20, Batch: 280/469, Loss: 0.224220
Epoch: 4/20, Batch: 290/469, Loss: 0.227844
Epoch: 4/20, Batch: 300/469, Loss: 0.223508
Epoch: 4/20, Batch: 310/469, Loss: 0.221039
Epoch: 4/20, Batch: 320/469, Loss: 0.228566
Epoch: 4/20, Batch: 330/469, Loss: 0.229608
Epoch: 4/20, Batch: 340/469, Loss: 0.229945
Epoch: 4/20, Batch: 350/469, Loss: 0.226479
Epoch: 4/20, Batch: 360/469, Loss: 0.220995
Epoch: 4/20, Batch: 370/469, Loss: 0.224024
Epoch: 4/20, Batch: 380/469, Loss: 0.228491
Epoch: 4/20, Batch: 390/469, Loss: 0.221185
Epoch: 4/20, Batch: 400/469, Loss: 0.226422
Epoch: 4/20, Batch: 410/469, Loss: 0.231617
Epoch: 4/20, Batch: 420/469, Loss: 0.217751
Epoch: 4/20, Batch: 430/469, Loss: 0.227823
Epoch: 4/20, Batch: 440/469, Loss: 0.220278
Epoch: 4/20, Batch: 450/469, Loss: 0.230993
Epoch: 4/20, Batch: 460/469, Loss: 0.229055
Epoch 4 average training loss: 0.228498
Epoch 4 test loss: 0.224772
Current learning rate: 0.001000
New best model with test loss: 0.224772
Epoch: 5/20, Batch: 0/469, Loss: 0.224118
Epoch: 5/20, Batch: 10/469, Loss: 0.232740
Epoch: 5/20, Batch: 20/469, Loss: 0.220216
Epoch: 5/20, Batch: 30/469, Loss: 0.228610
Epoch: 5/20, Batch: 40/469, Loss: 0.223448
Epoch: 5/20, Batch: 50/469, Loss: 0.220628
Epoch: 5/20, Batch: 60/469, Loss: 0.231074
Epoch: 5/20, Batch: 70/469, Loss: 0.224305
Epoch: 5/20, Batch: 80/469, Loss: 0.220817
Epoch: 5/20, Batch: 90/469, Loss: 0.236191
Epoch: 5/20, Batch: 100/469, Loss: 0.223346
Epoch: 5/20, Batch: 110/469, Loss: 0.222555
Epoch: 5/20, Batch: 120/469, Loss: 0.218005
Epoch: 5/20, Batch: 130/469, Loss: 0.226730
Epoch: 5/20, Batch: 140/469, Loss: 0.214425
Epoch: 5/20, Batch: 150/469, Loss: 0.230344
Epoch: 5/20, Batch: 160/469, Loss: 0.214178
Epoch: 5/20, Batch: 170/469, Loss: 0.221805
Epoch: 5/20, Batch: 180/469, Loss: 0.218111
Epoch: 5/20, Batch: 190/469, Loss: 0.230607
Epoch: 5/20, Batch: 200/469, Loss: 0.222635
Epoch: 5/20, Batch: 210/469, Loss: 0.222057
Epoch: 5/20, Batch: 220/469, Loss: 0.222385
Epoch: 5/20, Batch: 230/469, Loss: 0.226100
Epoch: 5/20, Batch: 240/469, Loss: 0.229697
Epoch: 5/20, Batch: 250/469, Loss: 0.223248
Epoch: 5/20, Batch: 260/469, Loss: 0.227900
Epoch: 5/20, Batch: 270/469, Loss: 0.225837
Epoch: 5/20, Batch: 280/469, Loss: 0.217221
Epoch: 5/20, Batch: 290/469, Loss: 0.221690
Epoch: 5/20, Batch: 300/469, Loss: 0.228511
Epoch: 5/20, Batch: 310/469, Loss: 0.222221
Epoch: 5/20, Batch: 320/469, Loss: 0.218108
Epoch: 5/20, Batch: 330/469, Loss: 0.225754
Epoch: 5/20, Batch: 340/469, Loss: 0.225743
Epoch: 5/20, Batch: 350/469, Loss: 0.226366
Epoch: 5/20, Batch: 360/469, Loss: 0.222165
Epoch: 5/20, Batch: 370/469, Loss: 0.217813
Epoch: 5/20, Batch: 380/469, Loss: 0.230557
Epoch: 5/20, Batch: 390/469, Loss: 0.224621
Epoch: 5/20, Batch: 400/469, Loss: 0.217610
Epoch: 5/20, Batch: 410/469, Loss: 0.226195
Epoch: 5/20, Batch: 420/469, Loss: 0.216146
Epoch: 5/20, Batch: 430/469, Loss: 0.221329
Epoch: 5/20, Batch: 440/469, Loss: 0.223935
Epoch: 5/20, Batch: 450/469, Loss: 0.223515
Epoch: 5/20, Batch: 460/469, Loss: 0.218860
Epoch 5 average training loss: 0.223410
Epoch 5 test loss: 0.221102
Current learning rate: 0.001000
New best model with test loss: 0.221102
Epoch: 6/20, Batch: 0/469, Loss: 0.217836
Epoch: 6/20, Batch: 10/469, Loss: 0.225832
Epoch: 6/20, Batch: 20/469, Loss: 0.221874
Epoch: 6/20, Batch: 30/469, Loss: 0.213838
Epoch: 6/20, Batch: 40/469, Loss: 0.223980
Epoch: 6/20, Batch: 50/469, Loss: 0.216957
Epoch: 6/20, Batch: 60/469, Loss: 0.224077
Epoch: 6/20, Batch: 70/469, Loss: 0.219094
Epoch: 6/20, Batch: 80/469, Loss: 0.224078
Epoch: 6/20, Batch: 90/469, Loss: 0.216071
Epoch: 6/20, Batch: 100/469, Loss: 0.220423
Epoch: 6/20, Batch: 110/469, Loss: 0.218762
Epoch: 6/20, Batch: 120/469, Loss: 0.224305
Epoch: 6/20, Batch: 130/469, Loss: 0.219479
Epoch: 6/20, Batch: 140/469, Loss: 0.216908
Epoch: 6/20, Batch: 150/469, Loss: 0.210678
Epoch: 6/20, Batch: 160/469, Loss: 0.226148
Epoch: 6/20, Batch: 170/469, Loss: 0.220384
Epoch: 6/20, Batch: 180/469, Loss: 0.222741
Epoch: 6/20, Batch: 190/469, Loss: 0.224896
Epoch: 6/20, Batch: 200/469, Loss: 0.222351
Epoch: 6/20, Batch: 210/469, Loss: 0.225332
Epoch: 6/20, Batch: 220/469, Loss: 0.219862
Epoch: 6/20, Batch: 230/469, Loss: 0.227074
Epoch: 6/20, Batch: 240/469, Loss: 0.223254
Epoch: 6/20, Batch: 250/469, Loss: 0.218581
Epoch: 6/20, Batch: 260/469, Loss: 0.222126
Epoch: 6/20, Batch: 270/469, Loss: 0.218980
Epoch: 6/20, Batch: 280/469, Loss: 0.215863
Epoch: 6/20, Batch: 290/469, Loss: 0.218395
Epoch: 6/20, Batch: 300/469, Loss: 0.218484
Epoch: 6/20, Batch: 310/469, Loss: 0.216774
Epoch: 6/20, Batch: 320/469, Loss: 0.223385
Epoch: 6/20, Batch: 330/469, Loss: 0.222342
Epoch: 6/20, Batch: 340/469, Loss: 0.221791
Epoch: 6/20, Batch: 350/469, Loss: 0.222638
Epoch: 6/20, Batch: 360/469, Loss: 0.221570
Epoch: 6/20, Batch: 370/469, Loss: 0.226992
Epoch: 6/20, Batch: 380/469, Loss: 0.218076
Epoch: 6/20, Batch: 390/469, Loss: 0.228809
Epoch: 6/20, Batch: 400/469, Loss: 0.225015
Epoch: 6/20, Batch: 410/469, Loss: 0.217947
Epoch: 6/20, Batch: 420/469, Loss: 0.218748
Epoch: 6/20, Batch: 430/469, Loss: 0.218244
Epoch: 6/20, Batch: 440/469, Loss: 0.217017
Epoch: 6/20, Batch: 450/469, Loss: 0.221016
Epoch: 6/20, Batch: 460/469, Loss: 0.216728
Epoch 6 average training loss: 0.220853
Epoch 6 test loss: 0.220251
Current learning rate: 0.001000
New best model with test loss: 0.220251
Epoch: 7/20, Batch: 0/469, Loss: 0.224952
Epoch: 7/20, Batch: 10/469, Loss: 0.216069
Epoch: 7/20, Batch: 20/469, Loss: 0.225738
Epoch: 7/20, Batch: 30/469, Loss: 0.218473
Epoch: 7/20, Batch: 40/469, Loss: 0.208255
Epoch: 7/20, Batch: 50/469, Loss: 0.219465
Epoch: 7/20, Batch: 60/469, Loss: 0.216269
Epoch: 7/20, Batch: 70/469, Loss: 0.219640
Epoch: 7/20, Batch: 80/469, Loss: 0.213736
Epoch: 7/20, Batch: 90/469, Loss: 0.223662
Epoch: 7/20, Batch: 100/469, Loss: 0.216803
Epoch: 7/20, Batch: 110/469, Loss: 0.222347
Epoch: 7/20, Batch: 120/469, Loss: 0.223461
Epoch: 7/20, Batch: 130/469, Loss: 0.217972
Epoch: 7/20, Batch: 140/469, Loss: 0.219932
Epoch: 7/20, Batch: 150/469, Loss: 0.224264
Epoch: 7/20, Batch: 160/469, Loss: 0.218657
Epoch: 7/20, Batch: 170/469, Loss: 0.222183
Epoch: 7/20, Batch: 180/469, Loss: 0.229374
Epoch: 7/20, Batch: 190/469, Loss: 0.216462
Epoch: 7/20, Batch: 200/469, Loss: 0.212875
Epoch: 7/20, Batch: 210/469, Loss: 0.225409
Epoch: 7/20, Batch: 220/469, Loss: 0.224333
Epoch: 7/20, Batch: 230/469, Loss: 0.219892
Epoch: 7/20, Batch: 240/469, Loss: 0.218604
Epoch: 7/20, Batch: 250/469, Loss: 0.219289
Epoch: 7/20, Batch: 260/469, Loss: 0.211656
Epoch: 7/20, Batch: 270/469, Loss: 0.225197
Epoch: 7/20, Batch: 280/469, Loss: 0.210968
Epoch: 7/20, Batch: 290/469, Loss: 0.216885
Epoch: 7/20, Batch: 300/469, Loss: 0.214910
Epoch: 7/20, Batch: 310/469, Loss: 0.214594
Epoch: 7/20, Batch: 320/469, Loss: 0.213676
Epoch: 7/20, Batch: 330/469, Loss: 0.217838
Epoch: 7/20, Batch: 340/469, Loss: 0.213745
Epoch: 7/20, Batch: 350/469, Loss: 0.221465
Epoch: 7/20, Batch: 360/469, Loss: 0.219727
Epoch: 7/20, Batch: 370/469, Loss: 0.214327
Epoch: 7/20, Batch: 380/469, Loss: 0.215226
Epoch: 7/20, Batch: 390/469, Loss: 0.213002
Epoch: 7/20, Batch: 400/469, Loss: 0.218513
Epoch: 7/20, Batch: 410/469, Loss: 0.212558
Epoch: 7/20, Batch: 420/469, Loss: 0.218478
Epoch: 7/20, Batch: 430/469, Loss: 0.210598
Epoch: 7/20, Batch: 440/469, Loss: 0.219975
Epoch: 7/20, Batch: 450/469, Loss: 0.216140
Epoch: 7/20, Batch: 460/469, Loss: 0.212611
Epoch 7 average training loss: 0.219044
Epoch 7 test loss: 0.216958
Current learning rate: 0.001000
New best model with test loss: 0.216958
Epoch: 8/20, Batch: 0/469, Loss: 0.223242
Epoch: 8/20, Batch: 10/469, Loss: 0.222034
Epoch: 8/20, Batch: 20/469, Loss: 0.216805
Epoch: 8/20, Batch: 30/469, Loss: 0.219820
Epoch: 8/20, Batch: 40/469, Loss: 0.217796
Epoch: 8/20, Batch: 50/469, Loss: 0.217689
Epoch: 8/20, Batch: 60/469, Loss: 0.224524
Epoch: 8/20, Batch: 70/469, Loss: 0.215307
Epoch: 8/20, Batch: 80/469, Loss: 0.220199
Epoch: 8/20, Batch: 90/469, Loss: 0.208169
Epoch: 8/20, Batch: 100/469, Loss: 0.222147
Epoch: 8/20, Batch: 110/469, Loss: 0.221772
Epoch: 8/20, Batch: 120/469, Loss: 0.215154
Epoch: 8/20, Batch: 130/469, Loss: 0.219635
Epoch: 8/20, Batch: 140/469, Loss: 0.222720
Epoch: 8/20, Batch: 150/469, Loss: 0.212889
Epoch: 8/20, Batch: 160/469, Loss: 0.210679
Epoch: 8/20, Batch: 170/469, Loss: 0.221016
Epoch: 8/20, Batch: 180/469, Loss: 0.211662
Epoch: 8/20, Batch: 190/469, Loss: 0.216154
Epoch: 8/20, Batch: 200/469, Loss: 0.211343
Epoch: 8/20, Batch: 210/469, Loss: 0.218122
Epoch: 8/20, Batch: 220/469, Loss: 0.221368
Epoch: 8/20, Batch: 230/469, Loss: 0.218594
Epoch: 8/20, Batch: 240/469, Loss: 0.224074
Epoch: 8/20, Batch: 250/469, Loss: 0.220542
Epoch: 8/20, Batch: 260/469, Loss: 0.223187
Epoch: 8/20, Batch: 270/469, Loss: 0.218089
Epoch: 8/20, Batch: 280/469, Loss: 0.219223
Epoch: 8/20, Batch: 290/469, Loss: 0.217609
Epoch: 8/20, Batch: 300/469, Loss: 0.214193
Epoch: 8/20, Batch: 310/469, Loss: 0.215090
Epoch: 8/20, Batch: 320/469, Loss: 0.217212
Epoch: 8/20, Batch: 330/469, Loss: 0.215767
Epoch: 8/20, Batch: 340/469, Loss: 0.213468
Epoch: 8/20, Batch: 350/469, Loss: 0.212441
Epoch: 8/20, Batch: 360/469, Loss: 0.213226
Epoch: 8/20, Batch: 370/469, Loss: 0.220515
Epoch: 8/20, Batch: 380/469, Loss: 0.214867
Epoch: 8/20, Batch: 390/469, Loss: 0.213904
Epoch: 8/20, Batch: 400/469, Loss: 0.213580
Epoch: 8/20, Batch: 410/469, Loss: 0.217249
Epoch: 8/20, Batch: 420/469, Loss: 0.221295
Epoch: 8/20, Batch: 430/469, Loss: 0.220053
Epoch: 8/20, Batch: 440/469, Loss: 0.221695
Epoch: 8/20, Batch: 450/469, Loss: 0.212456
Epoch: 8/20, Batch: 460/469, Loss: 0.213071
Epoch 8 average training loss: 0.217465
Epoch 8 test loss: 0.215530
Current learning rate: 0.001000
New best model with test loss: 0.215530
Epoch: 9/20, Batch: 0/469, Loss: 0.220720
Epoch: 9/20, Batch: 10/469, Loss: 0.212772
Epoch: 9/20, Batch: 20/469, Loss: 0.213314
Epoch: 9/20, Batch: 30/469, Loss: 0.214651
Epoch: 9/20, Batch: 40/469, Loss: 0.219689
Epoch: 9/20, Batch: 50/469, Loss: 0.212246
Epoch: 9/20, Batch: 60/469, Loss: 0.217186
Epoch: 9/20, Batch: 70/469, Loss: 0.217723
Epoch: 9/20, Batch: 80/469, Loss: 0.213060
Epoch: 9/20, Batch: 90/469, Loss: 0.209208
Epoch: 9/20, Batch: 100/469, Loss: 0.223697
Epoch: 9/20, Batch: 110/469, Loss: 0.214186
Epoch: 9/20, Batch: 120/469, Loss: 0.212217
Epoch: 9/20, Batch: 130/469, Loss: 0.218618
Epoch: 9/20, Batch: 140/469, Loss: 0.209719
Epoch: 9/20, Batch: 150/469, Loss: 0.225030
Epoch: 9/20, Batch: 160/469, Loss: 0.210015
Epoch: 9/20, Batch: 170/469, Loss: 0.221006
Epoch: 9/20, Batch: 180/469, Loss: 0.213555
Epoch: 9/20, Batch: 190/469, Loss: 0.221309
Epoch: 9/20, Batch: 200/469, Loss: 0.221660
Epoch: 9/20, Batch: 210/469, Loss: 0.211858
Epoch: 9/20, Batch: 220/469, Loss: 0.213579
Epoch: 9/20, Batch: 230/469, Loss: 0.212624
Epoch: 9/20, Batch: 240/469, Loss: 0.218029
Epoch: 9/20, Batch: 250/469, Loss: 0.207524
Epoch: 9/20, Batch: 260/469, Loss: 0.221810
Epoch: 9/20, Batch: 270/469, Loss: 0.214967
Epoch: 9/20, Batch: 280/469, Loss: 0.213709
Epoch: 9/20, Batch: 290/469, Loss: 0.211457
Epoch: 9/20, Batch: 300/469, Loss: 0.212359
Epoch: 9/20, Batch: 310/469, Loss: 0.215434
Epoch: 9/20, Batch: 320/469, Loss: 0.211776
Epoch: 9/20, Batch: 330/469, Loss: 0.217738
Epoch: 9/20, Batch: 340/469, Loss: 0.209256
Epoch: 9/20, Batch: 350/469, Loss: 0.223196
Epoch: 9/20, Batch: 360/469, Loss: 0.214401
Epoch: 9/20, Batch: 370/469, Loss: 0.214568
Epoch: 9/20, Batch: 380/469, Loss: 0.215216
Epoch: 9/20, Batch: 390/469, Loss: 0.213957
Epoch: 9/20, Batch: 400/469, Loss: 0.209357
Epoch: 9/20, Batch: 410/469, Loss: 0.219765
Epoch: 9/20, Batch: 420/469, Loss: 0.217457
Epoch: 9/20, Batch: 430/469, Loss: 0.215437
Epoch: 9/20, Batch: 440/469, Loss: 0.204476
Epoch: 9/20, Batch: 450/469, Loss: 0.215620
Epoch: 9/20, Batch: 460/469, Loss: 0.215351
Epoch 9 average training loss: 0.215185
Epoch 9 test loss: 0.212113
Current learning rate: 0.001000
New best model with test loss: 0.212113
Epoch: 10/20, Batch: 0/469, Loss: 0.217446
Epoch: 10/20, Batch: 10/469, Loss: 0.204490
Epoch: 10/20, Batch: 20/469, Loss: 0.208787
Epoch: 10/20, Batch: 30/469, Loss: 0.214365
Epoch: 10/20, Batch: 40/469, Loss: 0.214507
Epoch: 10/20, Batch: 50/469, Loss: 0.213939
Epoch: 10/20, Batch: 60/469, Loss: 0.213126
Epoch: 10/20, Batch: 70/469, Loss: 0.208756
Epoch: 10/20, Batch: 80/469, Loss: 0.215561
Epoch: 10/20, Batch: 90/469, Loss: 0.210820
Epoch: 10/20, Batch: 100/469, Loss: 0.213531
Epoch: 10/20, Batch: 110/469, Loss: 0.211186
Epoch: 10/20, Batch: 120/469, Loss: 0.211150
Epoch: 10/20, Batch: 130/469, Loss: 0.205882
Epoch: 10/20, Batch: 140/469, Loss: 0.206015
Epoch: 10/20, Batch: 150/469, Loss: 0.205237
Epoch: 10/20, Batch: 160/469, Loss: 0.208499
Epoch: 10/20, Batch: 170/469, Loss: 0.213678
Epoch: 10/20, Batch: 180/469, Loss: 0.205830
Epoch: 10/20, Batch: 190/469, Loss: 0.220421
Epoch: 10/20, Batch: 200/469, Loss: 0.209436
Epoch: 10/20, Batch: 210/469, Loss: 0.208071
Epoch: 10/20, Batch: 220/469, Loss: 0.213684
Epoch: 10/20, Batch: 230/469, Loss: 0.211148
Epoch: 10/20, Batch: 240/469, Loss: 0.209753
Epoch: 10/20, Batch: 250/469, Loss: 0.207899
Epoch: 10/20, Batch: 260/469, Loss: 0.214689
Epoch: 10/20, Batch: 270/469, Loss: 0.203741
Epoch: 10/20, Batch: 280/469, Loss: 0.216980
Epoch: 10/20, Batch: 290/469, Loss: 0.215852
Epoch: 10/20, Batch: 300/469, Loss: 0.207830
Epoch: 10/20, Batch: 310/469, Loss: 0.216609
Epoch: 10/20, Batch: 320/469, Loss: 0.213840
Epoch: 10/20, Batch: 330/469, Loss: 0.209219
Epoch: 10/20, Batch: 340/469, Loss: 0.217362
Epoch: 10/20, Batch: 350/469, Loss: 0.216105
Epoch: 10/20, Batch: 360/469, Loss: 0.215934
Epoch: 10/20, Batch: 370/469, Loss: 0.211426
Epoch: 10/20, Batch: 380/469, Loss: 0.213747
Epoch: 10/20, Batch: 390/469, Loss: 0.213301
Epoch: 10/20, Batch: 400/469, Loss: 0.214042
Epoch: 10/20, Batch: 410/469, Loss: 0.202593
Epoch: 10/20, Batch: 420/469, Loss: 0.216908
Epoch: 10/20, Batch: 430/469, Loss: 0.214638
Epoch: 10/20, Batch: 440/469, Loss: 0.204698
Epoch: 10/20, Batch: 450/469, Loss: 0.210592
Epoch: 10/20, Batch: 460/469, Loss: 0.213882
Epoch 10 average training loss: 0.212729
Epoch 10 test loss: 0.211134
Current learning rate: 0.001000
New best model with test loss: 0.211134
Epoch: 11/20, Batch: 0/469, Loss: 0.214763
Epoch: 11/20, Batch: 10/469, Loss: 0.219020
Epoch: 11/20, Batch: 20/469, Loss: 0.213717
Epoch: 11/20, Batch: 30/469, Loss: 0.209371
Epoch: 11/20, Batch: 40/469, Loss: 0.207829
Epoch: 11/20, Batch: 50/469, Loss: 0.207739
Epoch: 11/20, Batch: 60/469, Loss: 0.216472
Epoch: 11/20, Batch: 70/469, Loss: 0.211675
Epoch: 11/20, Batch: 80/469, Loss: 0.210735
Epoch: 11/20, Batch: 90/469, Loss: 0.214069
Epoch: 11/20, Batch: 100/469, Loss: 0.212001
Epoch: 11/20, Batch: 110/469, Loss: 0.212480
Epoch: 11/20, Batch: 120/469, Loss: 0.216690
Epoch: 11/20, Batch: 130/469, Loss: 0.206841
Epoch: 11/20, Batch: 140/469, Loss: 0.216091
Epoch: 11/20, Batch: 150/469, Loss: 0.209724
Epoch: 11/20, Batch: 160/469, Loss: 0.212609
Epoch: 11/20, Batch: 170/469, Loss: 0.211049
Epoch: 11/20, Batch: 180/469, Loss: 0.219174
Epoch: 11/20, Batch: 190/469, Loss: 0.211215
Epoch: 11/20, Batch: 200/469, Loss: 0.220481
Epoch: 11/20, Batch: 210/469, Loss: 0.215575
Epoch: 11/20, Batch: 220/469, Loss: 0.212099
Epoch: 11/20, Batch: 230/469, Loss: 0.215937
Epoch: 11/20, Batch: 240/469, Loss: 0.215622
Epoch: 11/20, Batch: 250/469, Loss: 0.209118
Epoch: 11/20, Batch: 260/469, Loss: 0.219197
Epoch: 11/20, Batch: 270/469, Loss: 0.204003
Epoch: 11/20, Batch: 280/469, Loss: 0.214908
Epoch: 11/20, Batch: 290/469, Loss: 0.208491
Epoch: 11/20, Batch: 300/469, Loss: 0.207092
Epoch: 11/20, Batch: 310/469, Loss: 0.215001
Epoch: 11/20, Batch: 320/469, Loss: 0.215088
Epoch: 11/20, Batch: 330/469, Loss: 0.214138
Epoch: 11/20, Batch: 340/469, Loss: 0.215318
Epoch: 11/20, Batch: 350/469, Loss: 0.207700
Epoch: 11/20, Batch: 360/469, Loss: 0.215858
Epoch: 11/20, Batch: 370/469, Loss: 0.213861
Epoch: 11/20, Batch: 380/469, Loss: 0.207773
Epoch: 11/20, Batch: 390/469, Loss: 0.214390
Epoch: 11/20, Batch: 400/469, Loss: 0.213081
Epoch: 11/20, Batch: 410/469, Loss: 0.203974
Epoch: 11/20, Batch: 420/469, Loss: 0.215607
Epoch: 11/20, Batch: 430/469, Loss: 0.214710
Epoch: 11/20, Batch: 440/469, Loss: 0.200902
Epoch: 11/20, Batch: 450/469, Loss: 0.206424
Epoch: 11/20, Batch: 460/469, Loss: 0.210066
Epoch 11 average training loss: 0.211831
Epoch 11 test loss: 0.211603
Current learning rate: 0.001000
No improvement for 1 epochs
Epoch: 12/20, Batch: 0/469, Loss: 0.211043
Epoch: 12/20, Batch: 10/469, Loss: 0.206815
Epoch: 12/20, Batch: 20/469, Loss: 0.215231
Epoch: 12/20, Batch: 30/469, Loss: 0.205282
Epoch: 12/20, Batch: 40/469, Loss: 0.207909
Epoch: 12/20, Batch: 50/469, Loss: 0.212512
Epoch: 12/20, Batch: 60/469, Loss: 0.213266
Epoch: 12/20, Batch: 70/469, Loss: 0.206314
Epoch: 12/20, Batch: 80/469, Loss: 0.211810
Epoch: 12/20, Batch: 90/469, Loss: 0.210911
Epoch: 12/20, Batch: 100/469, Loss: 0.210767
Epoch: 12/20, Batch: 110/469, Loss: 0.211135
Epoch: 12/20, Batch: 120/469, Loss: 0.217679
Epoch: 12/20, Batch: 130/469, Loss: 0.211737
Epoch: 12/20, Batch: 140/469, Loss: 0.211754
Epoch: 12/20, Batch: 150/469, Loss: 0.211974
Epoch: 12/20, Batch: 160/469, Loss: 0.210566
Epoch: 12/20, Batch: 170/469, Loss: 0.217121
Epoch: 12/20, Batch: 180/469, Loss: 0.221854
Epoch: 12/20, Batch: 190/469, Loss: 0.209463
Epoch: 12/20, Batch: 200/469, Loss: 0.213184
Epoch: 12/20, Batch: 210/469, Loss: 0.213088
Epoch: 12/20, Batch: 220/469, Loss: 0.216206
Epoch: 12/20, Batch: 230/469, Loss: 0.211522
Epoch: 12/20, Batch: 240/469, Loss: 0.218441
Epoch: 12/20, Batch: 250/469, Loss: 0.213431
Epoch: 12/20, Batch: 260/469, Loss: 0.211362
Epoch: 12/20, Batch: 270/469, Loss: 0.216138
Epoch: 12/20, Batch: 280/469, Loss: 0.210831
Epoch: 12/20, Batch: 290/469, Loss: 0.207998
Epoch: 12/20, Batch: 300/469, Loss: 0.216001
Epoch: 12/20, Batch: 310/469, Loss: 0.210044
Epoch: 12/20, Batch: 320/469, Loss: 0.212161
Epoch: 12/20, Batch: 330/469, Loss: 0.212306
Epoch: 12/20, Batch: 340/469, Loss: 0.203685
Epoch: 12/20, Batch: 350/469, Loss: 0.210994
Epoch: 12/20, Batch: 360/469, Loss: 0.218511
Epoch: 12/20, Batch: 370/469, Loss: 0.207289
Epoch: 12/20, Batch: 380/469, Loss: 0.213954
Epoch: 12/20, Batch: 390/469, Loss: 0.213009
Epoch: 12/20, Batch: 400/469, Loss: 0.210243
Epoch: 12/20, Batch: 410/469, Loss: 0.217228
Epoch: 12/20, Batch: 420/469, Loss: 0.204950
Epoch: 12/20, Batch: 430/469, Loss: 0.206365
Epoch: 12/20, Batch: 440/469, Loss: 0.207523
Epoch: 12/20, Batch: 450/469, Loss: 0.212558
Epoch: 12/20, Batch: 460/469, Loss: 0.214329
Epoch 12 average training loss: 0.211093
Epoch 12 test loss: 0.209030
Current learning rate: 0.001000
New best model with test loss: 0.209030
Epoch: 13/20, Batch: 0/469, Loss: 0.210873
Epoch: 13/20, Batch: 10/469, Loss: 0.206721
Epoch: 13/20, Batch: 20/469, Loss: 0.219018
Epoch: 13/20, Batch: 30/469, Loss: 0.212593
Epoch: 13/20, Batch: 40/469, Loss: 0.212240
Epoch: 13/20, Batch: 50/469, Loss: 0.217745
Epoch: 13/20, Batch: 60/469, Loss: 0.212483
Epoch: 13/20, Batch: 70/469, Loss: 0.210983
Epoch: 13/20, Batch: 80/469, Loss: 0.210471
Epoch: 13/20, Batch: 90/469, Loss: 0.206260
Epoch: 13/20, Batch: 100/469, Loss: 0.202827
Epoch: 13/20, Batch: 110/469, Loss: 0.212014
Epoch: 13/20, Batch: 120/469, Loss: 0.206353
Epoch: 13/20, Batch: 130/469, Loss: 0.212223
Epoch: 13/20, Batch: 140/469, Loss: 0.209265
Epoch: 13/20, Batch: 150/469, Loss: 0.209393
Epoch: 13/20, Batch: 160/469, Loss: 0.207558
Epoch: 13/20, Batch: 170/469, Loss: 0.208564
Epoch: 13/20, Batch: 180/469, Loss: 0.206930
Epoch: 13/20, Batch: 190/469, Loss: 0.214779
Epoch: 13/20, Batch: 200/469, Loss: 0.211891
Epoch: 13/20, Batch: 210/469, Loss: 0.202581
Epoch: 13/20, Batch: 220/469, Loss: 0.211638
Epoch: 13/20, Batch: 230/469, Loss: 0.215342
Epoch: 13/20, Batch: 240/469, Loss: 0.208895
Epoch: 13/20, Batch: 250/469, Loss: 0.212216
Epoch: 13/20, Batch: 260/469, Loss: 0.207107
Epoch: 13/20, Batch: 270/469, Loss: 0.203087
Epoch: 13/20, Batch: 280/469, Loss: 0.212842
Epoch: 13/20, Batch: 290/469, Loss: 0.207821
Epoch: 13/20, Batch: 300/469, Loss: 0.210786
Epoch: 13/20, Batch: 310/469, Loss: 0.202894
Epoch: 13/20, Batch: 320/469, Loss: 0.204934
Epoch: 13/20, Batch: 330/469, Loss: 0.209444
Epoch: 13/20, Batch: 340/469, Loss: 0.211696
Epoch: 13/20, Batch: 350/469, Loss: 0.211397
Epoch: 13/20, Batch: 360/469, Loss: 0.212550
Epoch: 13/20, Batch: 370/469, Loss: 0.206205
Epoch: 13/20, Batch: 380/469, Loss: 0.206187
Epoch: 13/20, Batch: 390/469, Loss: 0.214461
Epoch: 13/20, Batch: 400/469, Loss: 0.208282
Epoch: 13/20, Batch: 410/469, Loss: 0.212585
Epoch: 13/20, Batch: 420/469, Loss: 0.205367
Epoch: 13/20, Batch: 430/469, Loss: 0.200762
Epoch: 13/20, Batch: 440/469, Loss: 0.207483
Epoch: 13/20, Batch: 450/469, Loss: 0.209540
Epoch: 13/20, Batch: 460/469, Loss: 0.216215
Epoch 13 average training loss: 0.210340
Epoch 13 test loss: 0.208866
Current learning rate: 0.001000
New best model with test loss: 0.208866
Epoch: 14/20, Batch: 0/469, Loss: 0.212771
Epoch: 14/20, Batch: 10/469, Loss: 0.213373
Epoch: 14/20, Batch: 20/469, Loss: 0.199179
Epoch: 14/20, Batch: 30/469, Loss: 0.208138
Epoch: 14/20, Batch: 40/469, Loss: 0.205600
Epoch: 14/20, Batch: 50/469, Loss: 0.200619
Epoch: 14/20, Batch: 60/469, Loss: 0.209880
Epoch: 14/20, Batch: 70/469, Loss: 0.209947
Epoch: 14/20, Batch: 80/469, Loss: 0.221133
Epoch: 14/20, Batch: 90/469, Loss: 0.212685
Epoch: 14/20, Batch: 100/469, Loss: 0.213046
Epoch: 14/20, Batch: 110/469, Loss: 0.209327
Epoch: 14/20, Batch: 120/469, Loss: 0.208958
Epoch: 14/20, Batch: 130/469, Loss: 0.209591
Epoch: 14/20, Batch: 140/469, Loss: 0.205718
Epoch: 14/20, Batch: 150/469, Loss: 0.207877
Epoch: 14/20, Batch: 160/469, Loss: 0.210706
Epoch: 14/20, Batch: 170/469, Loss: 0.213620
Epoch: 14/20, Batch: 180/469, Loss: 0.206293
Epoch: 14/20, Batch: 190/469, Loss: 0.205369
Epoch: 14/20, Batch: 200/469, Loss: 0.201762
Epoch: 14/20, Batch: 210/469, Loss: 0.211083
Epoch: 14/20, Batch: 220/469, Loss: 0.205765
Epoch: 14/20, Batch: 230/469, Loss: 0.208427
Epoch: 14/20, Batch: 240/469, Loss: 0.209432
Epoch: 14/20, Batch: 250/469, Loss: 0.206332
Epoch: 14/20, Batch: 260/469, Loss: 0.206169
Epoch: 14/20, Batch: 270/469, Loss: 0.210023
Epoch: 14/20, Batch: 280/469, Loss: 0.211498
Epoch: 14/20, Batch: 290/469, Loss: 0.206233
Epoch: 14/20, Batch: 300/469, Loss: 0.208469
Epoch: 14/20, Batch: 310/469, Loss: 0.197393
Epoch: 14/20, Batch: 320/469, Loss: 0.213416
Epoch: 14/20, Batch: 330/469, Loss: 0.202493
Epoch: 14/20, Batch: 340/469, Loss: 0.210662
Epoch: 14/20, Batch: 350/469, Loss: 0.214152
Epoch: 14/20, Batch: 360/469, Loss: 0.209098
Epoch: 14/20, Batch: 370/469, Loss: 0.208248
Epoch: 14/20, Batch: 380/469, Loss: 0.203246
Epoch: 14/20, Batch: 390/469, Loss: 0.209957
Epoch: 14/20, Batch: 400/469, Loss: 0.207177
Epoch: 14/20, Batch: 410/469, Loss: 0.204137
Epoch: 14/20, Batch: 420/469, Loss: 0.209079
Epoch: 14/20, Batch: 430/469, Loss: 0.210575
Epoch: 14/20, Batch: 440/469, Loss: 0.209693
Epoch: 14/20, Batch: 450/469, Loss: 0.216841
Epoch: 14/20, Batch: 460/469, Loss: 0.211794
Epoch 14 average training loss: 0.209595
Epoch 14 test loss: 0.209902
Current learning rate: 0.001000
No improvement for 1 epochs
Epoch: 15/20, Batch: 0/469, Loss: 0.207367
Epoch: 15/20, Batch: 10/469, Loss: 0.204525
Epoch: 15/20, Batch: 20/469, Loss: 0.208534
Epoch: 15/20, Batch: 30/469, Loss: 0.213829
Epoch: 15/20, Batch: 40/469, Loss: 0.206152
Epoch: 15/20, Batch: 50/469, Loss: 0.206549
Epoch: 15/20, Batch: 60/469, Loss: 0.211006
Epoch: 15/20, Batch: 70/469, Loss: 0.201411
Epoch: 15/20, Batch: 80/469, Loss: 0.208558
Epoch: 15/20, Batch: 90/469, Loss: 0.208011
Epoch: 15/20, Batch: 100/469, Loss: 0.219043
Epoch: 15/20, Batch: 110/469, Loss: 0.204353
Epoch: 15/20, Batch: 120/469, Loss: 0.212703
Epoch: 15/20, Batch: 130/469, Loss: 0.211189
Epoch: 15/20, Batch: 140/469, Loss: 0.209152
Epoch: 15/20, Batch: 150/469, Loss: 0.206217
Epoch: 15/20, Batch: 160/469, Loss: 0.207095
Epoch: 15/20, Batch: 170/469, Loss: 0.201941
Epoch: 15/20, Batch: 180/469, Loss: 0.218086
Epoch: 15/20, Batch: 190/469, Loss: 0.204842
Epoch: 15/20, Batch: 200/469, Loss: 0.208185
Epoch: 15/20, Batch: 210/469, Loss: 0.212400
Epoch: 15/20, Batch: 220/469, Loss: 0.216292
Epoch: 15/20, Batch: 230/469, Loss: 0.212196
Epoch: 15/20, Batch: 240/469, Loss: 0.210542
Epoch: 15/20, Batch: 250/469, Loss: 0.202203
Epoch: 15/20, Batch: 260/469, Loss: 0.209709
Epoch: 15/20, Batch: 270/469, Loss: 0.203776
Epoch: 15/20, Batch: 280/469, Loss: 0.209241
Epoch: 15/20, Batch: 290/469, Loss: 0.203724
Epoch: 15/20, Batch: 300/469, Loss: 0.210055
Epoch: 15/20, Batch: 310/469, Loss: 0.209016
Epoch: 15/20, Batch: 320/469, Loss: 0.218477
Epoch: 15/20, Batch: 330/469, Loss: 0.207926
Epoch: 15/20, Batch: 340/469, Loss: 0.213963
Epoch: 15/20, Batch: 350/469, Loss: 0.206167
Epoch: 15/20, Batch: 360/469, Loss: 0.212611
Epoch: 15/20, Batch: 370/469, Loss: 0.205985
Epoch: 15/20, Batch: 380/469, Loss: 0.206174
Epoch: 15/20, Batch: 390/469, Loss: 0.202105
Epoch: 15/20, Batch: 400/469, Loss: 0.208659
Epoch: 15/20, Batch: 410/469, Loss: 0.207371
Epoch: 15/20, Batch: 420/469, Loss: 0.212601
Epoch: 15/20, Batch: 430/469, Loss: 0.204636
Epoch: 15/20, Batch: 440/469, Loss: 0.216152
Epoch: 15/20, Batch: 450/469, Loss: 0.207220
Epoch: 15/20, Batch: 460/469, Loss: 0.208323
Epoch 15 average training loss: 0.209439
Epoch 15 test loss: 0.209275
Current learning rate: 0.001000
No improvement for 2 epochs
Epoch: 16/20, Batch: 0/469, Loss: 0.215293
Epoch: 16/20, Batch: 10/469, Loss: 0.213600
Epoch: 16/20, Batch: 20/469, Loss: 0.210042
Epoch: 16/20, Batch: 30/469, Loss: 0.210281
Epoch: 16/20, Batch: 40/469, Loss: 0.206723
Epoch: 16/20, Batch: 50/469, Loss: 0.210724
Epoch: 16/20, Batch: 60/469, Loss: 0.215392
Epoch: 16/20, Batch: 70/469, Loss: 0.212187
Epoch: 16/20, Batch: 80/469, Loss: 0.213717
Epoch: 16/20, Batch: 90/469, Loss: 0.208688
Epoch: 16/20, Batch: 100/469, Loss: 0.208639
Epoch: 16/20, Batch: 110/469, Loss: 0.212730
Epoch: 16/20, Batch: 120/469, Loss: 0.217212
Epoch: 16/20, Batch: 130/469, Loss: 0.205065
Epoch: 16/20, Batch: 140/469, Loss: 0.210744
Epoch: 16/20, Batch: 150/469, Loss: 0.205616
Epoch: 16/20, Batch: 160/469, Loss: 0.206760
Epoch: 16/20, Batch: 170/469, Loss: 0.205589
Epoch: 16/20, Batch: 180/469, Loss: 0.208279
Epoch: 16/20, Batch: 190/469, Loss: 0.215898
Epoch: 16/20, Batch: 200/469, Loss: 0.209263
Epoch: 16/20, Batch: 210/469, Loss: 0.209366
Epoch: 16/20, Batch: 220/469, Loss: 0.203065
Epoch: 16/20, Batch: 230/469, Loss: 0.206546
Epoch: 16/20, Batch: 240/469, Loss: 0.207617
Epoch: 16/20, Batch: 250/469, Loss: 0.209491
Epoch: 16/20, Batch: 260/469, Loss: 0.213042
Epoch: 16/20, Batch: 270/469, Loss: 0.213224
Epoch: 16/20, Batch: 280/469, Loss: 0.209430
Epoch: 16/20, Batch: 290/469, Loss: 0.212271
Epoch: 16/20, Batch: 300/469, Loss: 0.212083
Epoch: 16/20, Batch: 310/469, Loss: 0.210680
Epoch: 16/20, Batch: 320/469, Loss: 0.208942
Epoch: 16/20, Batch: 330/469, Loss: 0.209899
Epoch: 16/20, Batch: 340/469, Loss: 0.215008
Epoch: 16/20, Batch: 350/469, Loss: 0.207944
Epoch: 16/20, Batch: 360/469, Loss: 0.214226
Epoch: 16/20, Batch: 370/469, Loss: 0.205244
Epoch: 16/20, Batch: 380/469, Loss: 0.201590
Epoch: 16/20, Batch: 390/469, Loss: 0.202800
Epoch: 16/20, Batch: 400/469, Loss: 0.205809
Epoch: 16/20, Batch: 410/469, Loss: 0.202333
Epoch: 16/20, Batch: 420/469, Loss: 0.214613
Epoch: 16/20, Batch: 430/469, Loss: 0.206397
Epoch: 16/20, Batch: 440/469, Loss: 0.208402
Epoch: 16/20, Batch: 450/469, Loss: 0.202472
Epoch: 16/20, Batch: 460/469, Loss: 0.203844
Epoch 16 average training loss: 0.209325
Epoch 16 test loss: 0.206827
Current learning rate: 0.001000
New best model with test loss: 0.206827
Epoch: 17/20, Batch: 0/469, Loss: 0.202556
Epoch: 17/20, Batch: 10/469, Loss: 0.207812
Epoch: 17/20, Batch: 20/469, Loss: 0.204858
Epoch: 17/20, Batch: 30/469, Loss: 0.211211
Epoch: 17/20, Batch: 40/469, Loss: 0.200099
Epoch: 17/20, Batch: 50/469, Loss: 0.206627
Epoch: 17/20, Batch: 60/469, Loss: 0.208037
Epoch: 17/20, Batch: 70/469, Loss: 0.208377
Epoch: 17/20, Batch: 80/469, Loss: 0.210629
Epoch: 17/20, Batch: 90/469, Loss: 0.214234
Epoch: 17/20, Batch: 100/469, Loss: 0.212004
Epoch: 17/20, Batch: 110/469, Loss: 0.212999
Epoch: 17/20, Batch: 120/469, Loss: 0.209832
Epoch: 17/20, Batch: 130/469, Loss: 0.210132
Epoch: 17/20, Batch: 140/469, Loss: 0.209750
Epoch: 17/20, Batch: 150/469, Loss: 0.216717
Epoch: 17/20, Batch: 160/469, Loss: 0.201560
Epoch: 17/20, Batch: 170/469, Loss: 0.209595
Epoch: 17/20, Batch: 180/469, Loss: 0.209703
Epoch: 17/20, Batch: 190/469, Loss: 0.218761
Epoch: 17/20, Batch: 200/469, Loss: 0.206203
Epoch: 17/20, Batch: 210/469, Loss: 0.212779
Epoch: 17/20, Batch: 220/469, Loss: 0.205639
Epoch: 17/20, Batch: 230/469, Loss: 0.210037
Epoch: 17/20, Batch: 240/469, Loss: 0.209341
Epoch: 17/20, Batch: 250/469, Loss: 0.213321
Epoch: 17/20, Batch: 260/469, Loss: 0.204016
Epoch: 17/20, Batch: 270/469, Loss: 0.206974
Epoch: 17/20, Batch: 280/469, Loss: 0.211707
Epoch: 17/20, Batch: 290/469, Loss: 0.218377
Epoch: 17/20, Batch: 300/469, Loss: 0.209888
Epoch: 17/20, Batch: 310/469, Loss: 0.213094
Epoch: 17/20, Batch: 320/469, Loss: 0.206226
Epoch: 17/20, Batch: 330/469, Loss: 0.207443
Epoch: 17/20, Batch: 340/469, Loss: 0.206221
Epoch: 17/20, Batch: 350/469, Loss: 0.209708
Epoch: 17/20, Batch: 360/469, Loss: 0.208175
Epoch: 17/20, Batch: 370/469, Loss: 0.215754
Epoch: 17/20, Batch: 380/469, Loss: 0.206180
Epoch: 17/20, Batch: 390/469, Loss: 0.213457
Epoch: 17/20, Batch: 400/469, Loss: 0.207036
Epoch: 17/20, Batch: 410/469, Loss: 0.207870
Epoch: 17/20, Batch: 420/469, Loss: 0.207152
Epoch: 17/20, Batch: 430/469, Loss: 0.207636
Epoch: 17/20, Batch: 440/469, Loss: 0.206491
Epoch: 17/20, Batch: 450/469, Loss: 0.208486
Epoch: 17/20, Batch: 460/469, Loss: 0.207429
Epoch 17 average training loss: 0.209234
Epoch 17 test loss: 0.207524
Current learning rate: 0.001000
No improvement for 1 epochs
Epoch: 18/20, Batch: 0/469, Loss: 0.202099
Epoch: 18/20, Batch: 10/469, Loss: 0.214507
Epoch: 18/20, Batch: 20/469, Loss: 0.207212
Epoch: 18/20, Batch: 30/469, Loss: 0.216066
Epoch: 18/20, Batch: 40/469, Loss: 0.213026
Epoch: 18/20, Batch: 50/469, Loss: 0.220747
Epoch: 18/20, Batch: 60/469, Loss: 0.210804
Epoch: 18/20, Batch: 70/469, Loss: 0.214038
Epoch: 18/20, Batch: 80/469, Loss: 0.207343
Epoch: 18/20, Batch: 90/469, Loss: 0.214469
Epoch: 18/20, Batch: 100/469, Loss: 0.208596
Epoch: 18/20, Batch: 110/469, Loss: 0.212159
Epoch: 18/20, Batch: 120/469, Loss: 0.204651
Epoch: 18/20, Batch: 130/469, Loss: 0.211621
Epoch: 18/20, Batch: 140/469, Loss: 0.210633
Epoch: 18/20, Batch: 150/469, Loss: 0.210425
Epoch: 18/20, Batch: 160/469, Loss: 0.203219
Epoch: 18/20, Batch: 170/469, Loss: 0.211141
Epoch: 18/20, Batch: 180/469, Loss: 0.209775
Epoch: 18/20, Batch: 190/469, Loss: 0.211429
Epoch: 18/20, Batch: 200/469, Loss: 0.205494
Epoch: 18/20, Batch: 210/469, Loss: 0.218133
Epoch: 18/20, Batch: 220/469, Loss: 0.204454
Epoch: 18/20, Batch: 230/469, Loss: 0.210245
Epoch: 18/20, Batch: 240/469, Loss: 0.205546
Epoch: 18/20, Batch: 250/469, Loss: 0.207433
Epoch: 18/20, Batch: 260/469, Loss: 0.208362
Epoch: 18/20, Batch: 270/469, Loss: 0.209357
Epoch: 18/20, Batch: 280/469, Loss: 0.211110
Epoch: 18/20, Batch: 290/469, Loss: 0.209117
Epoch: 18/20, Batch: 300/469, Loss: 0.204339
Epoch: 18/20, Batch: 310/469, Loss: 0.208376
Epoch: 18/20, Batch: 320/469, Loss: 0.213858
Epoch: 18/20, Batch: 330/469, Loss: 0.213692
Epoch: 18/20, Batch: 340/469, Loss: 0.207779
Epoch: 18/20, Batch: 350/469, Loss: 0.205959
Epoch: 18/20, Batch: 360/469, Loss: 0.203879
Epoch: 18/20, Batch: 370/469, Loss: 0.204162
Epoch: 18/20, Batch: 380/469, Loss: 0.205290
Epoch: 18/20, Batch: 390/469, Loss: 0.212493
Epoch: 18/20, Batch: 400/469, Loss: 0.200532
Epoch: 18/20, Batch: 410/469, Loss: 0.211167
Epoch: 18/20, Batch: 420/469, Loss: 0.208053
Epoch: 18/20, Batch: 430/469, Loss: 0.214512
Epoch: 18/20, Batch: 440/469, Loss: 0.201038
Epoch: 18/20, Batch: 450/469, Loss: 0.213429
Epoch: 18/20, Batch: 460/469, Loss: 0.204935
Epoch 18 average training loss: 0.209099
Epoch 18 test loss: 0.206903
Current learning rate: 0.001000
No improvement for 2 epochs
Epoch: 19/20, Batch: 0/469, Loss: 0.207390
Epoch: 19/20, Batch: 10/469, Loss: 0.206777
Epoch: 19/20, Batch: 20/469, Loss: 0.212507
Epoch: 19/20, Batch: 30/469, Loss: 0.209222
Epoch: 19/20, Batch: 40/469, Loss: 0.203609
Epoch: 19/20, Batch: 50/469, Loss: 0.207364
Epoch: 19/20, Batch: 60/469, Loss: 0.202807
Epoch: 19/20, Batch: 70/469, Loss: 0.206227
Epoch: 19/20, Batch: 80/469, Loss: 0.207104
Epoch: 19/20, Batch: 90/469, Loss: 0.206067
Epoch: 19/20, Batch: 100/469, Loss: 0.207564
Epoch: 19/20, Batch: 110/469, Loss: 0.206372
Epoch: 19/20, Batch: 120/469, Loss: 0.212251
Epoch: 19/20, Batch: 130/469, Loss: 0.217493
Epoch: 19/20, Batch: 140/469, Loss: 0.208805
Epoch: 19/20, Batch: 150/469, Loss: 0.213019
Epoch: 19/20, Batch: 160/469, Loss: 0.210528
Epoch: 19/20, Batch: 170/469, Loss: 0.209955
Epoch: 19/20, Batch: 180/469, Loss: 0.211152
Epoch: 19/20, Batch: 190/469, Loss: 0.205525
Epoch: 19/20, Batch: 200/469, Loss: 0.211048
Epoch: 19/20, Batch: 210/469, Loss: 0.212028
Epoch: 19/20, Batch: 220/469, Loss: 0.208617
Epoch: 19/20, Batch: 230/469, Loss: 0.208809
Epoch: 19/20, Batch: 240/469, Loss: 0.203364
Epoch: 19/20, Batch: 250/469, Loss: 0.210144
Epoch: 19/20, Batch: 260/469, Loss: 0.210097
Epoch: 19/20, Batch: 270/469, Loss: 0.214254
Epoch: 19/20, Batch: 280/469, Loss: 0.208904
Epoch: 19/20, Batch: 290/469, Loss: 0.218724
Epoch: 19/20, Batch: 300/469, Loss: 0.199022
Epoch: 19/20, Batch: 310/469, Loss: 0.210516
Epoch: 19/20, Batch: 320/469, Loss: 0.209234
Epoch: 19/20, Batch: 330/469, Loss: 0.210159
Epoch: 19/20, Batch: 340/469, Loss: 0.198755
Epoch: 19/20, Batch: 350/469, Loss: 0.211019
Epoch: 19/20, Batch: 360/469, Loss: 0.199160
Epoch: 19/20, Batch: 370/469, Loss: 0.210670
Epoch: 19/20, Batch: 380/469, Loss: 0.212555
Epoch: 19/20, Batch: 390/469, Loss: 0.212382
Epoch: 19/20, Batch: 400/469, Loss: 0.207561
Epoch: 19/20, Batch: 410/469, Loss: 0.209851
Epoch: 19/20, Batch: 420/469, Loss: 0.210867
Epoch: 19/20, Batch: 430/469, Loss: 0.209916
Epoch: 19/20, Batch: 440/469, Loss: 0.206035
Epoch: 19/20, Batch: 450/469, Loss: 0.206083
Epoch: 19/20, Batch: 460/469, Loss: 0.207148
Epoch 19 average training loss: 0.208926
Epoch 19 test loss: 0.210435
Current learning rate: 0.000500
No improvement for 3 epochs
Epoch: 20/20, Batch: 0/469, Loss: 0.211074
Epoch: 20/20, Batch: 10/469, Loss: 0.204439
Epoch: 20/20, Batch: 20/469, Loss: 0.208146
Epoch: 20/20, Batch: 30/469, Loss: 0.201846
Epoch: 20/20, Batch: 40/469, Loss: 0.207145
Epoch: 20/20, Batch: 50/469, Loss: 0.207307
Epoch: 20/20, Batch: 60/469, Loss: 0.202797
Epoch: 20/20, Batch: 70/469, Loss: 0.210327
Epoch: 20/20, Batch: 80/469, Loss: 0.205838
Epoch: 20/20, Batch: 90/469, Loss: 0.203876
Epoch: 20/20, Batch: 100/469, Loss: 0.205854
Epoch: 20/20, Batch: 110/469, Loss: 0.204464
Epoch: 20/20, Batch: 120/469, Loss: 0.206882
Epoch: 20/20, Batch: 130/469, Loss: 0.206017
Epoch: 20/20, Batch: 140/469, Loss: 0.202314
Epoch: 20/20, Batch: 150/469, Loss: 0.209197
Epoch: 20/20, Batch: 160/469, Loss: 0.204208
Epoch: 20/20, Batch: 170/469, Loss: 0.203223
Epoch: 20/20, Batch: 180/469, Loss: 0.203935
Epoch: 20/20, Batch: 190/469, Loss: 0.206490
Epoch: 20/20, Batch: 200/469, Loss: 0.204954
Epoch: 20/20, Batch: 210/469, Loss: 0.207911
Epoch: 20/20, Batch: 220/469, Loss: 0.210991
Epoch: 20/20, Batch: 230/469, Loss: 0.199542
Epoch: 20/20, Batch: 240/469, Loss: 0.202942
Epoch: 20/20, Batch: 250/469, Loss: 0.207271
Epoch: 20/20, Batch: 260/469, Loss: 0.203368
Epoch: 20/20, Batch: 270/469, Loss: 0.198808
Epoch: 20/20, Batch: 280/469, Loss: 0.203112
Epoch: 20/20, Batch: 290/469, Loss: 0.203000
Epoch: 20/20, Batch: 300/469, Loss: 0.202056
Epoch: 20/20, Batch: 310/469, Loss: 0.205928
Epoch: 20/20, Batch: 320/469, Loss: 0.201867
Epoch: 20/20, Batch: 330/469, Loss: 0.204746
Epoch: 20/20, Batch: 340/469, Loss: 0.201211
Epoch: 20/20, Batch: 350/469, Loss: 0.204882
Epoch: 20/20, Batch: 360/469, Loss: 0.205628
Epoch: 20/20, Batch: 370/469, Loss: 0.203249
Epoch: 20/20, Batch: 380/469, Loss: 0.205749
Epoch: 20/20, Batch: 390/469, Loss: 0.204313
Epoch: 20/20, Batch: 400/469, Loss: 0.201741
Epoch: 20/20, Batch: 410/469, Loss: 0.207491
Epoch: 20/20, Batch: 420/469, Loss: 0.205125
Epoch: 20/20, Batch: 430/469, Loss: 0.196546
Epoch: 20/20, Batch: 440/469, Loss: 0.207992
Epoch: 20/20, Batch: 450/469, Loss: 0.203138
Epoch: 20/20, Batch: 460/469, Loss: 0.207264
Epoch 20 average training loss: 0.205626
Epoch 20 test loss: 0.204501
Current learning rate: 0.000500
New best model with test loss: 0.204501
Loaded best model with test loss: 0.204501
Final Test Loss: 0.2045
No description has been provided for this image
samples shape:  (100, 28, 28, 3)
No description has been provided for this image

Question 3: Causal Transformer - iGPT¶

Now we will move onto the current most popular and widespread autoregressive model, the transformer.

Part (a) Autoregressive Transformer on Shapes and MNIST¶

In this part, implement a simple Autoregressive Transformer to model binary MNIST and shapes images (same as Q2(a), but with a Transformer).

Some additional notes about your transformer implementation:

  • iGPT uses learned positional encodings. We recommend to use those here as well. However, you may also use sinusoidal positional encodings if you wish (see the Attention is All You Need paper)
  • Autoregressive transformer always predicts the next token, give prior tokens. iGPT has a special <bos> or beginning of sequence token at the start of every sequence every image. Make sure to include this in your implementation as well. You can generate unconditional sample by conditioning with the <bos> token.
  • While dropout is a common feature in transformer models, you do not need to add it (but may if you wish!).
  • Prebuilt transformers exist in some frameworks (i.e. pytorch). Don't just use an off the shelf implementation as the point of the exercise is to better understand the transformer architecture. Building the transformer from the ground up (use primitives such as Linear/Dense layers, LayerNorm, GeLU, Embedding)
  • Learning rate warmup and cos learning rate decay are often used when training transformers to improve training stability and improve performance. See if this helps your model! Try 1000 steps of warmup with a cosine learning rate decay.

Paper references

  • Attention Is All You Need
  • Generative Pretraining from Pixels
  • Language Models are Unsupervised Multitask Learners

We recommend the following network design parameters:

  • $d_{model}$: 128
  • heads: 4
  • layers: 2
  • GeLU nonlinearities

And the following hyperparameters:

  • Batch size: 64 or 32 or 16 (whichever fits in your GPU)
  • Learning rate: $10^{-3}$
  • 15 epochs or more
  • Adam Optimizer (this applies to all Transformers models trained in future parts)

You will provide these deliverables

  1. Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and test data (for your entire test set). Code is provided that automatically plots the training curves.
  2. Report the final test set performance of your final model
  3. 100 samples from the final trained model
In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """
        q: (batch_size, n_heads, seq_len, head_size)
        k: (batch_size, n_heads, seq_len, head_size)
        v: (batch_size, n_heads, seq_len, head_size)
        """
        d_k = q.shape[-1]
        scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(d_k)  # (batch_size, n_heads, seq_len, seq_len)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attention_weights = torch.softmax(scores, dim=-1)  # (batch_size, n_heads, seq_len, seq_len)
        attention_weights = self.dropout(attention_weights)  # (batch_size, n_heads, seq_len, seq_len)

        output = torch.matmul(attention_weights, v)  # (batch_size, n_heads, seq_len, head_size)
        return output

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.0, cache=False):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_size = d_model // n_heads
        self.use_cache = cache

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.attention = ScaledDotProductAttention(dropout=dropout)
        self.cached_k = None
        self.cached_v = None

    def split_heads(self, x):
        """
        x: (batch_size, seq_len, d_model)
        """
        batch_size, seq_len, d_model = x.shape
        return x.view(batch_size, seq_len, self.n_heads, self.head_size).transpose(1, 2)  # (batch_size, n_heads, seq_len, head_size)
    
    def combine_heads(self, x):
        """
        x: (batch_size, n_heads, seq_len, head_size)
        """
        batch_size, n_heads, seq_len, head_size = x.shape
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)  # (batch_size, seq_len, d_model)
    
    def forward(self, x, mask=None, use_cache=False, past_key_values=None):
        batch_size, seq_len, d_model = x.shape
        if past_key_values is not None:
            self.cached_k, self.cached_v = past_key_values
        q = self.W_q(x)  # (batch_size, seq_len, d_model)
        k = self.W_k(x)
        v = self.W_v(x)

        q = self.split_heads(q)  # (batch_size, n_heads, seq_len, head_size)
        k = self.split_heads(k)
        v = self.split_heads(v)
        
        # Use KV cache if enabled
        if use_cache and self.cached_k is not None and self.cached_v is not None:
            # Concatenate current k, v with cached k, v
            k = torch.cat([self.cached_k, k], dim=2)
            v = torch.cat([self.cached_v, v], dim=2)


            self.cached_k = k
            self.cached_v = v
        
        # Create causal mask if needed
        if mask is None:
            # If using cache, adjust mask to account for the full sequence length
            full_seq_len = k.size(2)
            # For cached version, we need to adjust the mask to allow attention to all past tokens
            if use_cache and self.cached_k is not None:
                # Create a mask where current tokens can attend to all previous tokens
                # Current sequence position is at seq_len
                seq_position = seq_len
                # Create a mask that allows each token to see itself and all previous tokens
                mask = torch.ones(seq_len, full_seq_len).to(x.device)
                # Make it causal by setting future positions to 0
                mask[:, seq_position:] = 0
            else:
                # Standard causal mask for the full sequence
                mask = torch.tril(torch.ones(full_seq_len, full_seq_len)).to(x.device)
            
            mask = mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)

        # Use the attention module directly
        output = self.attention(q, k, v, mask)  # (batch_size, n_heads, seq_len, head_size)
        
        # Combine heads
        output = self.combine_heads(output)  # (batch_size, seq_len, d_model)
        past_key_values = (k, v)
        if use_cache:
            return self.dropout(self.out(output)) , past_key_values
        else:
            return self.dropout(self.out(output))
    
    def clear_cache(self):
        self.cached_k = None
        self.cached_v = None
    
class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1, use_cache=False):
        super().__init__()
        self.masked_mha = MultiHeadAttention(d_model, n_heads, dropout, cache=use_cache)
        self.layer_norm1 = nn.LayerNorm(d_model)
        
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),  
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout)
        )
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, use_cache=False, past_key_values=None):

        # Self-attention with residual connection and layer normalization
        residual = x
        x = self.layer_norm1(x)  # Pre-norm architecture
        if use_cache and past_key_values is not None:
            x, past_key_values = self.masked_mha(x, use_cache=use_cache, past_key_values=past_key_values)
        else:
            x = self.masked_mha(x)
            
        x = residual + x  # Residual connection
        
        # Feed forward with residual connection and layer normalization
        residual = x
        x = self.layer_norm2(x)  # Pre-norm architecture
        x = self.feed_forward(x)
        x = residual + x  # Residual connection
        if use_cache:
            return x , past_key_values
        else:
            return x
    
    def clear_cache(self):
        self.masked_mha.clear_cache()

class iGPT(nn.Module):
    def __init__(self, vocab_size, context_length, d_model, n_heads, n_layers, dropout=0.1, use_cache=False):
        super().__init__()
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.d_model = d_model
        self.n_heads = n_heads  
        self.n_layers = n_layers
        self.dropout = dropout
        self.use_cache = use_cache
        
        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        
        # Positional embedding (learned, as per iGPT specs)
        self.position_embedding = nn.Embedding(context_length, d_model)
        
        # Dropout
        self.dropout = nn.Dropout(dropout)
        
        # Stack of decoder layers
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, dropout, use_cache=use_cache) 
            for _ in range(n_layers)
        ])
        
        # Final layer norm
        self.layer_norm = nn.LayerNorm(d_model)
        
        # Output projection
        self.output_projection = nn.Linear(d_model, vocab_size)

    def forward(self, x, past_key_values=None, use_cache=False):
        # x shape: (batch_size, seq_len)
        batch_size, seq_len = x.shape
        device = x.device
        
        # Create position indices
        positions = torch.arange(0, seq_len, dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, -1)
        
        # Get embeddings
        token_emb = self.token_embedding(x)  # (batch_size, seq_len, d_model)
        pos_emb = self.position_embedding(positions)  # (batch_size, seq_len, d_model)
        
        # Combine embeddings
        x = token_emb + pos_emb  # (batch_size, seq_len, d_model)
        x = self.dropout(x)
        
        # Apply decoder layers
        past_key_values = None
        for layer in self.decoder_layers:
            if use_cache:
                x, past_key_values = layer(x, use_cache=use_cache, past_key_values=past_key_values)
            else:
                x = layer(x)
        
        # Apply final layer norm
        x = self.layer_norm(x)  # (batch_size, seq_len, d_model)
        
        # Project to vocabulary
        logits = self.output_projection(x)  # (batch_size, seq_len, vocab_size)
        if use_cache:
            return logits, past_key_values
        else:
            return logits
    
    def clear_cache(self):
        for layer in self.decoder_layers:
            layer.clear_cache()
In [11]:
def test_igpt():
    # Define dummy parameters
    vocab_size = 10
    context_length = 20
    d_model = 128
    n_heads = 4
    n_layers = 2
    batch_size = 5
    seq_len = context_length

    # Create a dummy input tensor
    dummy_input = torch.randint(0, vocab_size, (batch_size, seq_len))

    # Initialize the iGPT model
    model = iGPT(vocab_size, context_length, d_model, n_heads, n_layers)

    # Test token embedding
    token_emb = model.token_embedding(dummy_input)
    print("Token embedding shape:", token_emb.shape)
    assert token_emb.shape == (batch_size, seq_len, d_model), "Token embedding shape mismatch!"
    
    # Test position embedding
    positions = torch.arange(0, seq_len, dtype=torch.long, device=dummy_input.device).unsqueeze(0).expand(batch_size, -1)
    pos_emb = model.position_embedding(positions)
    print("Position embedding shape:", pos_emb.shape)
    assert pos_emb.shape == (batch_size, seq_len, d_model), "Position embedding shape mismatch!"
    
    # Test each decoder layer
    x = token_emb + pos_emb
    x = model.dropout(x)
    for i, layer in enumerate(model.decoder_layers):
        x_before = x.clone()
        x = layer(x)
        print(f"Decoder layer {i} output shape:", x.shape)
        assert x.shape == (batch_size, seq_len, d_model), f"Decoder layer {i} output shape mismatch!"
        # Check that the layer actually modified the input
        assert not torch.allclose(x, x_before), f"Decoder layer {i} did not modify the input!"
    
    # Test final layer norm
    x_before = x.clone()
    x = model.layer_norm(x)
    print("Layer norm output shape:", x.shape)
    assert x.shape == (batch_size, seq_len, d_model), "Layer norm output shape mismatch!"
    
    # Test output projection
    logits = model.output_projection(x)
    print("Output logits shape:", logits.shape)
    assert logits.shape == (batch_size, seq_len, vocab_size), "Output logits shape mismatch!"

    # Full forward pass
    output = model(dummy_input)
    print("Final output shape:", output.shape)
    assert output.shape == (batch_size, seq_len, vocab_size), "Final output shape mismatch!"

    print("iGPT model test passed! All layers are implemented correctly.")

# Run the test
test_igpt()
Token embedding shape: torch.Size([5, 20, 128])
Position embedding shape: torch.Size([5, 20, 128])
Decoder layer 0 output shape: torch.Size([5, 20, 128])
Decoder layer 1 output shape: torch.Size([5, 20, 128])
Layer norm output shape: torch.Size([5, 20, 128])
Output logits shape: torch.Size([5, 20, 10])
Final output shape: torch.Size([5, 20, 10])
iGPT model test passed! All layers are implemented correctly.
In [12]:
def generate_samples(model, sequence_length, vocab_size, image_shape, device, num_samples=100, use_cache=False, test_mode=False):
    """
    Generates samples from the trained model.
    
    Args:
        model: The trained iGPT model
        sequence_length: Length of token sequences including <bos>
        vocab_size: Size of vocabulary
        image_shape: (H, W, C) tuple specifying image dimensions
        device: Device to run generation on
        num_samples: Number of samples to generate
        use_cache: Whether to use caching for faster sampling
        test_mode: If True, only generate first 5 samples and fill the rest with blank images
        
    Returns:
        Numpy array of generated samples with shape (num_samples, H, W, C)
        and a list of generation times
    """
    H, W, C = image_shape
    model.eval()
    samples = []
    import time
    time_list = []
    
    # Determine how many samples to actually generate
    samples_to_generate = 5 if test_mode else num_samples
    
    with torch.no_grad():
        for i in range(num_samples):
            if test_mode and i >= samples_to_generate:
                # In test mode, fill remaining samples with blank images
                if C == 3:
                    blank_sample = np.zeros((H, W, C), dtype=np.uint8)
                else:
                    blank_sample = np.zeros((H, W, C), dtype=np.uint8)
                samples.append(blank_sample)
                time_list.append(0.0)  # No time spent on blank images
                continue
                
            start_time = time.time()
            
            # Start with just the <bos> token
            sample = torch.zeros(1, sequence_length, dtype=torch.long, device=device)
            sample[:, 0] = 0  # <bos> token
            
            # Cache for key-value pairs if using caching
            past_key_values = None
            
            # Autoregressive generation - one token at a time
            for i in range(1, sequence_length):
                if use_cache and past_key_values is not None:
                    # Only need to process the new token with cached key-values
                    logits, past_key_values = model(sample[:, i-1:i], past_key_values=past_key_values, use_cache=True)
                    # print(f"logits shape: {logits.shape}")
                    logits = logits[:, -1, :]  # Get prediction for current position
                else:
                    # Process the entire sequence up to current position
                    if use_cache:
                        logits, past_key_values = model(sample[:, :i], use_cache=True)
                        logits = logits[:, -1, :]  # Get prediction for current position
                    else:
                        logits = model(sample)
                        logits = logits[:, i-1, :] # Get prediction for current position
                # print(f"past_key_values: {past_key_values}")
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, 1).squeeze(-1)
                sample[:, i] = next_token
            
            # Convert tokens back to image format (remove <bos> token)
            sample_tokens = sample[:, 1:].cpu().numpy().reshape(H, W)
            
            # Convert single tokens back to RGB values if needed
            if C == 3:
                sample_rgb = np.zeros((H, W, C), dtype=np.uint8)
                sample_rgb[:, :, 0] = sample_tokens // 16  # R = token // 16
                sample_rgb[:, :, 1] = (sample_tokens % 16) // 4  # G = (token % 16) // 4
                sample_rgb[:, :, 2] = sample_tokens % 4  # B = token % 4
                samples.append(sample_rgb)
            else:
                samples.append(sample_tokens.reshape(H, W, C))
            
            end_time = time.time()
            time_list.append(end_time - start_time)
    
    return np.array(samples), np.array(time_list)
In [13]:
import math

def create_dataset(data, image_shape, batch_size):
    """
    Converts image data to token sequences and creates PyTorch DataLoader.
    
    Args:
        data: A (n_samples, H, W, C) uint8 numpy array of images
        image_shape: (H, W, C) tuple specifying image dimensions
        batch_size: Batch size for DataLoader
        
    Returns:
        DataLoader object with tokenized image sequences
    """
    H, W, C = image_shape
    
    # Convert RGB pixels to single tokens (4 values per channel = 64 possible values)
    # Shape: (n_samples, H, W, C) -> (n_samples, H, W)
    if C == 3:
        # Convert RGB values to a single token: r*16 + g*4 + b
        # Each channel has values in {0,1,2,3}, so we can encode as a single number 0-63
        data_tokens = (data[:,:,:,0] * 16 + data[:,:,:,1] * 4 + data[:,:,:,2])
    else:
        # For grayscale, just use the values directly
        data_tokens = data.reshape(-1, H, W)
    
    # Flatten spatial dimensions to create sequences
    # Shape: (n_samples, H, W) -> (n_samples, H*W)
    data_flat = data_tokens.reshape(-1, H * W)
    
    # Convert to PyTorch tensors
    dataset = torch.utils.data.TensorDataset(torch.tensor(data_flat, dtype=torch.long))
    
    # Create data loader
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

def evaluate_model(model, data_loader, sequence_length, vocab_size, device):
    """
    Evaluates model performance on a dataset.
    
    Args:
        model: The iGPT model
        data_loader: DataLoader containing tokenized images
        sequence_length: Length of token sequences including <bos>
        vocab_size: Size of vocabulary
        device: Device to run evaluation on
        
    Returns:
        Average loss (negative log-likelihood) per dimension
    """
    model.eval()
    total_loss = 0
    total_samples = 0
    
    with torch.no_grad():
        for (data,) in data_loader:
            data = data.to(device)  # Shape: (batch_size, sequence_length-1)
            batch_size = data.size(0)
            
            # Create input with <bos> token (0) at the beginning
            # Shape: (batch_size, sequence_length)
            input_seq = torch.zeros(batch_size, sequence_length, dtype=torch.long, device=device)
            input_seq[:, 0] = 0  # <bos> token
            input_seq[:, 1:] = data  # actual image data
            
            # Create targets (the image tokens to predict)
            # Shape: (batch_size, sequence_length-1)
            targets = data
            
            # Forward pass
            # Shape: (batch_size, sequence_length, vocab_size) -> (batch_size, sequence_length-1, vocab_size)
            logits = model(input_seq)[:, :-1, :]  # Remove last position's prediction
            
            # Compute loss
            loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1), reduction='sum')
            
            total_loss += loss.item()
            total_samples += batch_size * (sequence_length - 1)
    
    return total_loss / total_samples



def train_igpt(model, train_loader, test_loader, sequence_length, vocab_size, 
               device, num_epochs, learning_rate):
    """
    Trains the iGPT model.
    
    Args:
        model: The iGPT model to train
        train_loader: DataLoader for training data
        test_loader: DataLoader for test data
        sequence_length: Length of token sequences including <bos>
        vocab_size: Size of vocabulary
        device: Device to train on
        num_epochs: Number of training epochs
        learning_rate: Initial learning rate
        
    Returns:
        train_losses: Array of training losses per minibatch
        test_losses: Array of test losses per epoch
    """
    # Initialize optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    # Learning rate scheduler with warmup and cosine decay
    warmup_steps = 1000
    total_steps = len(train_loader) * num_epochs
    
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        else:
            decay_ratio = (step - warmup_steps) / (total_steps - warmup_steps)
            return 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    
    # Initialize arrays to store losses
    train_losses = []
    test_losses = [evaluate_model(model, test_loader, sequence_length, vocab_size, device)]
    
    # Training loop
    for epoch in range(num_epochs):
        model.train()
        epoch_losses = []
        
        for batch_idx, (data,) in enumerate(train_loader):
            data = data.to(device)  # Shape: (batch_size, sequence_length-1)
            batch_size = data.size(0)
            
            # Create input with <bos> token (0) at the beginning
            # Shape: (batch_size, sequence_length)
            input_seq = torch.zeros(batch_size, sequence_length, dtype=torch.long, device=device)
            input_seq[:, 0] = 0  # <bos> token
            input_seq[:, 1:] = data  # actual image data 
            
            # Create targets (the image tokens to predict)
            # Shape: (batch_size, sequence_length-1)
            targets = data
            
            # Forward pass
            # Shape: (batch_size, sequence_length, vocab_size) -> (batch_size, sequence_length-1, vocab_size)
            logits = model(input_seq)[:, :-1, :]  # Remove last position's prediction (don't predict <eos>)
            
            # Compute loss
            loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
            
            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            # Record loss
            train_losses.append(loss.item())
            epoch_losses.append(loss.item())
            
            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")
        
        # Evaluate on test set after each epoch
        test_loss = evaluate_model(model, test_loader, sequence_length, vocab_size, device)
        test_losses.append(test_loss)
        print(f"Epoch {epoch+1}/{num_epochs} completed. Test Loss: {test_loss:.4f}")
    
    return np.array(train_losses), np.array(test_losses)

def q3_a(train_data, test_data, image_shape, dset_id):
    """
    train_data: A (n_train, H, W, 1) uint8 numpy array of color images with values in {0, 1}
    test_data: A (n_test, H, W, 1) uint8 numpy array of color images with values in {0, 1}
    image_shape: (H, W, 1), height, width, and # of channels of the image
    dset_id: An identifying number of which dataset is given (1 or 2). Most likely
             used to set different hyperparameters for different datasets

    Returns
    - a (# of training iterations,) numpy array of train_losses evaluated every minibatch
    - a (# of epochs + 1,) numpy array of test_losses evaluated once at initialization and after each epoch
    - a numpy array of size (100, H, W, 1) of samples with values in {0, 1}
    """
    # Hyperparameters
    batch_size = 64
    learning_rate = 1e-3
    num_epochs = 15
    
    # Model parameters as recommended in the instructions
    d_model = 128
    n_heads = 4
    n_layers = 2
    
    # Determine sequence length and vocabulary size
    H, W, C = image_shape
    sequence_length = H * W * C + 1  # +1 for <bos> token
    vocab_size = 2  # Binary images with values in {0, 1}
    
    # Create datasets and data loaders
    train_loader = create_dataset(train_data, image_shape, batch_size)
    test_loader = create_dataset(test_data, image_shape, batch_size)
    
    # Initialize model and move to device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = iGPT(vocab_size, sequence_length, d_model, n_heads, n_layers).to(device)
    
    # Train the model
    train_losses, test_losses = train_igpt(model, train_loader, test_loader, 
                                          sequence_length, vocab_size, device,
                                          num_epochs, learning_rate)
    
    # Generate samples
    # save the model
    torch.save(model, 'model_no_cache.pth')
    samples , _= generate_samples(model, sequence_length, vocab_size, image_shape, device)
    
    return train_losses, test_losses, samples

Results¶

Once you've implemented q3_a, execute the cells below to visualize and save your results

In [11]:
q3ab_save_results(1, 'a', q3_a)
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
Epoch 1/15, Batch 0/164, Loss: 0.7565
Epoch 1/15, Batch 50/164, Loss: 0.3444
Epoch 1/15, Batch 100/164, Loss: 0.2132
Epoch 1/15, Batch 150/164, Loss: 0.2036
Epoch 1/15 completed. Test Loss: 0.1965
Epoch 2/15, Batch 0/164, Loss: 0.1971
Epoch 2/15, Batch 50/164, Loss: 0.1753
Epoch 2/15, Batch 100/164, Loss: 0.1454
Epoch 2/15, Batch 150/164, Loss: 0.1234
Epoch 2/15 completed. Test Loss: 0.1108
Epoch 3/15, Batch 0/164, Loss: 0.1324
Epoch 3/15, Batch 50/164, Loss: 0.1107
Epoch 3/15, Batch 100/164, Loss: 0.0993
Epoch 3/15, Batch 150/164, Loss: 0.1016
Epoch 3/15 completed. Test Loss: 0.0928
Epoch 4/15, Batch 0/164, Loss: 0.0973
Epoch 4/15, Batch 50/164, Loss: 0.0934
Epoch 4/15, Batch 100/164, Loss: 0.0873
Epoch 4/15, Batch 150/164, Loss: 0.0844
Epoch 4/15 completed. Test Loss: 0.0796
Epoch 5/15, Batch 0/164, Loss: 0.0899
Epoch 5/15, Batch 50/164, Loss: 0.0852
Epoch 5/15, Batch 100/164, Loss: 0.0808
Epoch 5/15, Batch 150/164, Loss: 0.0717
Epoch 5/15 completed. Test Loss: 0.0698
Epoch 6/15, Batch 0/164, Loss: 0.0789
Epoch 6/15, Batch 50/164, Loss: 0.0682
Epoch 6/15, Batch 100/164, Loss: 0.0645
Epoch 6/15, Batch 150/164, Loss: 0.0703
Epoch 6/15 completed. Test Loss: 0.0632
Epoch 7/15, Batch 0/164, Loss: 0.0690
Epoch 7/15, Batch 50/164, Loss: 0.0636
Epoch 7/15, Batch 100/164, Loss: 0.0620
Epoch 7/15, Batch 150/164, Loss: 0.0666
Epoch 7/15 completed. Test Loss: 0.0578
Epoch 8/15, Batch 0/164, Loss: 0.0619
Epoch 8/15, Batch 50/164, Loss: 0.0614
Epoch 8/15, Batch 100/164, Loss: 0.0648
Epoch 8/15, Batch 150/164, Loss: 0.0571
Epoch 8/15 completed. Test Loss: 0.0538
Epoch 9/15, Batch 0/164, Loss: 0.0603
Epoch 9/15, Batch 50/164, Loss: 0.0560
Epoch 9/15, Batch 100/164, Loss: 0.0531
Epoch 9/15, Batch 150/164, Loss: 0.0582
Epoch 9/15 completed. Test Loss: 0.0515
Epoch 10/15, Batch 0/164, Loss: 0.0600
Epoch 10/15, Batch 50/164, Loss: 0.0551
Epoch 10/15, Batch 100/164, Loss: 0.0550
Epoch 10/15, Batch 150/164, Loss: 0.0522
Epoch 10/15 completed. Test Loss: 0.0493
Epoch 11/15, Batch 0/164, Loss: 0.0491
Epoch 11/15, Batch 50/164, Loss: 0.0547
Epoch 11/15, Batch 100/164, Loss: 0.0544
Epoch 11/15, Batch 150/164, Loss: 0.0545
Epoch 11/15 completed. Test Loss: 0.0481
Epoch 12/15, Batch 0/164, Loss: 0.0584
Epoch 12/15, Batch 50/164, Loss: 0.0540
Epoch 12/15, Batch 100/164, Loss: 0.0552
Epoch 12/15, Batch 150/164, Loss: 0.0537
Epoch 12/15 completed. Test Loss: 0.0473
Epoch 13/15, Batch 0/164, Loss: 0.0509
Epoch 13/15, Batch 50/164, Loss: 0.0541
Epoch 13/15, Batch 100/164, Loss: 0.0493
Epoch 13/15, Batch 150/164, Loss: 0.0562
Epoch 13/15 completed. Test Loss: 0.0467
Epoch 14/15, Batch 0/164, Loss: 0.0487
Epoch 14/15, Batch 50/164, Loss: 0.0509
Epoch 14/15, Batch 100/164, Loss: 0.0518
Epoch 14/15, Batch 150/164, Loss: 0.0517
Epoch 14/15 completed. Test Loss: 0.0465
Epoch 15/15, Batch 0/164, Loss: 0.0497
Epoch 15/15, Batch 50/164, Loss: 0.0526
Epoch 15/15, Batch 100/164, Loss: 0.0508
Epoch 15/15, Batch 150/164, Loss: 0.0521
Epoch 15/15 completed. Test Loss: 0.0465
Final Test Loss: 0.0465
No description has been provided for this image
samples shape:  (100, 20, 20, 1)
No description has been provided for this image
In [12]:
q3ab_save_results(2, 'a', q3_a)
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
Epoch 1/15, Batch 0/938, Loss: 0.6212
Epoch 1/15, Batch 50/938, Loss: 0.2837
Epoch 1/15, Batch 100/938, Loss: 0.2267
Epoch 1/15, Batch 150/938, Loss: 0.2190
Epoch 1/15, Batch 200/938, Loss: 0.2150
Epoch 1/15, Batch 250/938, Loss: 0.2095
Epoch 1/15, Batch 300/938, Loss: 0.1933
Epoch 1/15, Batch 350/938, Loss: 0.1886
Epoch 1/15, Batch 400/938, Loss: 0.1882
Epoch 1/15, Batch 450/938, Loss: 0.1887
Epoch 1/15, Batch 500/938, Loss: 0.1821
Epoch 1/15, Batch 550/938, Loss: 0.1730
Epoch 1/15, Batch 600/938, Loss: 0.1748
Epoch 1/15, Batch 650/938, Loss: 0.1798
Epoch 1/15, Batch 700/938, Loss: 0.1679
Epoch 1/15, Batch 750/938, Loss: 0.1672
Epoch 1/15, Batch 800/938, Loss: 0.1667
Epoch 1/15, Batch 850/938, Loss: 0.1572
Epoch 1/15, Batch 900/938, Loss: 0.1585
Epoch 1/15 completed. Test Loss: 0.1423
Epoch 2/15, Batch 0/938, Loss: 0.1555
Epoch 2/15, Batch 50/938, Loss: 0.1412
Epoch 2/15, Batch 100/938, Loss: 0.1427
Epoch 2/15, Batch 150/938, Loss: 0.1342
Epoch 2/15, Batch 200/938, Loss: 0.1320
Epoch 2/15, Batch 250/938, Loss: 0.1263
Epoch 2/15, Batch 300/938, Loss: 0.1251
Epoch 2/15, Batch 350/938, Loss: 0.1259
Epoch 2/15, Batch 400/938, Loss: 0.1237
Epoch 2/15, Batch 450/938, Loss: 0.1127
Epoch 2/15, Batch 500/938, Loss: 0.1134
Epoch 2/15, Batch 550/938, Loss: 0.1098
Epoch 2/15, Batch 600/938, Loss: 0.1132
Epoch 2/15, Batch 650/938, Loss: 0.1166
Epoch 2/15, Batch 700/938, Loss: 0.1157
Epoch 2/15, Batch 750/938, Loss: 0.1137
Epoch 2/15, Batch 800/938, Loss: 0.1190
Epoch 2/15, Batch 850/938, Loss: 0.1151
Epoch 2/15, Batch 900/938, Loss: 0.1169
Epoch 2/15 completed. Test Loss: 0.1027
Epoch 3/15, Batch 0/938, Loss: 0.1138
Epoch 3/15, Batch 50/938, Loss: 0.1075
Epoch 3/15, Batch 100/938, Loss: 0.1063
Epoch 3/15, Batch 150/938, Loss: 0.1028
Epoch 3/15, Batch 200/938, Loss: 0.1101
Epoch 3/15, Batch 250/938, Loss: 0.1131
Epoch 3/15, Batch 300/938, Loss: 0.1020
Epoch 3/15, Batch 350/938, Loss: 0.1067
Epoch 3/15, Batch 400/938, Loss: 0.1033
Epoch 3/15, Batch 450/938, Loss: 0.1099
Epoch 3/15, Batch 500/938, Loss: 0.1051
Epoch 3/15, Batch 550/938, Loss: 0.1086
Epoch 3/15, Batch 600/938, Loss: 0.1006
Epoch 3/15, Batch 650/938, Loss: 0.1089
Epoch 3/15, Batch 700/938, Loss: 0.1042
Epoch 3/15, Batch 750/938, Loss: 0.1095
Epoch 3/15, Batch 800/938, Loss: 0.1015
Epoch 3/15, Batch 850/938, Loss: 0.0945
Epoch 3/15, Batch 900/938, Loss: 0.1031
Epoch 3/15 completed. Test Loss: 0.0942
Epoch 4/15, Batch 0/938, Loss: 0.1038
Epoch 4/15, Batch 50/938, Loss: 0.0974
Epoch 4/15, Batch 100/938, Loss: 0.0976
Epoch 4/15, Batch 150/938, Loss: 0.1023
Epoch 4/15, Batch 200/938, Loss: 0.1002
Epoch 4/15, Batch 250/938, Loss: 0.0990
Epoch 4/15, Batch 300/938, Loss: 0.1043
Epoch 4/15, Batch 350/938, Loss: 0.1012
Epoch 4/15, Batch 400/938, Loss: 0.1005
Epoch 4/15, Batch 450/938, Loss: 0.1032
Epoch 4/15, Batch 500/938, Loss: 0.0986
Epoch 4/15, Batch 550/938, Loss: 0.0998
Epoch 4/15, Batch 600/938, Loss: 0.0955
Epoch 4/15, Batch 650/938, Loss: 0.0958
Epoch 4/15, Batch 700/938, Loss: 0.0912
Epoch 4/15, Batch 750/938, Loss: 0.0970
Epoch 4/15, Batch 800/938, Loss: 0.0925
Epoch 4/15, Batch 850/938, Loss: 0.0960
Epoch 4/15, Batch 900/938, Loss: 0.0965
Epoch 4/15 completed. Test Loss: 0.0903
Epoch 5/15, Batch 0/938, Loss: 0.0977
Epoch 5/15, Batch 50/938, Loss: 0.0946
Epoch 5/15, Batch 100/938, Loss: 0.0898
Epoch 5/15, Batch 150/938, Loss: 0.0943
Epoch 5/15, Batch 200/938, Loss: 0.0954
Epoch 5/15, Batch 250/938, Loss: 0.0882
Epoch 5/15, Batch 300/938, Loss: 0.0968
Epoch 5/15, Batch 350/938, Loss: 0.0936
Epoch 5/15, Batch 400/938, Loss: 0.0939
Epoch 5/15, Batch 450/938, Loss: 0.0911
Epoch 5/15, Batch 500/938, Loss: 0.0911
Epoch 5/15, Batch 550/938, Loss: 0.0944
Epoch 5/15, Batch 600/938, Loss: 0.0946
Epoch 5/15, Batch 650/938, Loss: 0.0917
Epoch 5/15, Batch 700/938, Loss: 0.0940
Epoch 5/15, Batch 750/938, Loss: 0.0956
Epoch 5/15, Batch 800/938, Loss: 0.0923
Epoch 5/15, Batch 850/938, Loss: 0.0938
Epoch 5/15, Batch 900/938, Loss: 0.0873
Epoch 5/15 completed. Test Loss: 0.0880
Epoch 6/15, Batch 0/938, Loss: 0.0955
Epoch 6/15, Batch 50/938, Loss: 0.0928
Epoch 6/15, Batch 100/938, Loss: 0.0969
Epoch 6/15, Batch 150/938, Loss: 0.0914
Epoch 6/15, Batch 200/938, Loss: 0.0913
Epoch 6/15, Batch 250/938, Loss: 0.0922
Epoch 6/15, Batch 300/938, Loss: 0.0900
Epoch 6/15, Batch 350/938, Loss: 0.0939
Epoch 6/15, Batch 400/938, Loss: 0.0883
Epoch 6/15, Batch 450/938, Loss: 0.0998
Epoch 6/15, Batch 500/938, Loss: 0.0911
Epoch 6/15, Batch 550/938, Loss: 0.0918
Epoch 6/15, Batch 600/938, Loss: 0.0913
Epoch 6/15, Batch 650/938, Loss: 0.0941
Epoch 6/15, Batch 700/938, Loss: 0.0936
Epoch 6/15, Batch 750/938, Loss: 0.0923
Epoch 6/15, Batch 800/938, Loss: 0.0869
Epoch 6/15, Batch 850/938, Loss: 0.0854
Epoch 6/15, Batch 900/938, Loss: 0.0890
Epoch 6/15 completed. Test Loss: 0.0852
Epoch 7/15, Batch 0/938, Loss: 0.0907
Epoch 7/15, Batch 50/938, Loss: 0.0869
Epoch 7/15, Batch 100/938, Loss: 0.0891
Epoch 7/15, Batch 150/938, Loss: 0.0905
Epoch 7/15, Batch 200/938, Loss: 0.0886
Epoch 7/15, Batch 250/938, Loss: 0.0889
Epoch 7/15, Batch 300/938, Loss: 0.0960
Epoch 7/15, Batch 350/938, Loss: 0.0868
Epoch 7/15, Batch 400/938, Loss: 0.0964
Epoch 7/15, Batch 450/938, Loss: 0.0897
Epoch 7/15, Batch 500/938, Loss: 0.0885
Epoch 7/15, Batch 550/938, Loss: 0.0915
Epoch 7/15, Batch 600/938, Loss: 0.0906
Epoch 7/15, Batch 650/938, Loss: 0.0864
Epoch 7/15, Batch 700/938, Loss: 0.0852
Epoch 7/15, Batch 750/938, Loss: 0.0880
Epoch 7/15, Batch 800/938, Loss: 0.0888
Epoch 7/15, Batch 850/938, Loss: 0.0813
Epoch 7/15, Batch 900/938, Loss: 0.0898
Epoch 7/15 completed. Test Loss: 0.0835
Epoch 8/15, Batch 0/938, Loss: 0.0873
Epoch 8/15, Batch 50/938, Loss: 0.0872
Epoch 8/15, Batch 100/938, Loss: 0.0908
Epoch 8/15, Batch 150/938, Loss: 0.0904
Epoch 8/15, Batch 200/938, Loss: 0.0869
Epoch 8/15, Batch 250/938, Loss: 0.0844
Epoch 8/15, Batch 300/938, Loss: 0.0903
Epoch 8/15, Batch 350/938, Loss: 0.0834
Epoch 8/15, Batch 400/938, Loss: 0.0842
Epoch 8/15, Batch 450/938, Loss: 0.0839
Epoch 8/15, Batch 500/938, Loss: 0.0870
Epoch 8/15, Batch 550/938, Loss: 0.0879
Epoch 8/15, Batch 600/938, Loss: 0.0842
Epoch 8/15, Batch 650/938, Loss: 0.0883
Epoch 8/15, Batch 700/938, Loss: 0.0861
Epoch 8/15, Batch 750/938, Loss: 0.0850
Epoch 8/15, Batch 800/938, Loss: 0.0797
Epoch 8/15, Batch 850/938, Loss: 0.0834
Epoch 8/15, Batch 900/938, Loss: 0.0931
Epoch 8/15 completed. Test Loss: 0.0828
Epoch 9/15, Batch 0/938, Loss: 0.0853
Epoch 9/15, Batch 50/938, Loss: 0.0869
Epoch 9/15, Batch 100/938, Loss: 0.0900
Epoch 9/15, Batch 150/938, Loss: 0.0850
Epoch 9/15, Batch 200/938, Loss: 0.0873
Epoch 9/15, Batch 250/938, Loss: 0.0848
Epoch 9/15, Batch 300/938, Loss: 0.0917
Epoch 9/15, Batch 350/938, Loss: 0.0872
Epoch 9/15, Batch 400/938, Loss: 0.0863
Epoch 9/15, Batch 450/938, Loss: 0.0914
Epoch 9/15, Batch 500/938, Loss: 0.0878
Epoch 9/15, Batch 550/938, Loss: 0.0839
Epoch 9/15, Batch 600/938, Loss: 0.0850
Epoch 9/15, Batch 650/938, Loss: 0.0913
Epoch 9/15, Batch 700/938, Loss: 0.0865
Epoch 9/15, Batch 750/938, Loss: 0.0877
Epoch 9/15, Batch 800/938, Loss: 0.0853
Epoch 9/15, Batch 850/938, Loss: 0.0892
Epoch 9/15, Batch 900/938, Loss: 0.0905
Epoch 9/15 completed. Test Loss: 0.0817
Epoch 10/15, Batch 0/938, Loss: 0.0804
Epoch 10/15, Batch 50/938, Loss: 0.0857
Epoch 10/15, Batch 100/938, Loss: 0.0792
Epoch 10/15, Batch 150/938, Loss: 0.0889
Epoch 10/15, Batch 200/938, Loss: 0.0847
Epoch 10/15, Batch 250/938, Loss: 0.0863
Epoch 10/15, Batch 300/938, Loss: 0.0876
Epoch 10/15, Batch 350/938, Loss: 0.0854
Epoch 10/15, Batch 400/938, Loss: 0.0824
Epoch 10/15, Batch 450/938, Loss: 0.0854
Epoch 10/15, Batch 500/938, Loss: 0.0889
Epoch 10/15, Batch 550/938, Loss: 0.0803
Epoch 10/15, Batch 600/938, Loss: 0.0826
Epoch 10/15, Batch 650/938, Loss: 0.0901
Epoch 10/15, Batch 700/938, Loss: 0.0889
Epoch 10/15, Batch 750/938, Loss: 0.0833
Epoch 10/15, Batch 800/938, Loss: 0.0811
Epoch 10/15, Batch 850/938, Loss: 0.0868
Epoch 10/15, Batch 900/938, Loss: 0.0864
Epoch 10/15 completed. Test Loss: 0.0806
Epoch 11/15, Batch 0/938, Loss: 0.0873
Epoch 11/15, Batch 50/938, Loss: 0.0871
Epoch 11/15, Batch 100/938, Loss: 0.0860
Epoch 11/15, Batch 150/938, Loss: 0.0835
Epoch 11/15, Batch 200/938, Loss: 0.0884
Epoch 11/15, Batch 250/938, Loss: 0.0830
Epoch 11/15, Batch 300/938, Loss: 0.0852
Epoch 11/15, Batch 350/938, Loss: 0.0793
Epoch 11/15, Batch 400/938, Loss: 0.0858
Epoch 11/15, Batch 450/938, Loss: 0.0822
Epoch 11/15, Batch 500/938, Loss: 0.0810
Epoch 11/15, Batch 550/938, Loss: 0.0882
Epoch 11/15, Batch 600/938, Loss: 0.0854
Epoch 11/15, Batch 650/938, Loss: 0.0831
Epoch 11/15, Batch 700/938, Loss: 0.0854
Epoch 11/15, Batch 750/938, Loss: 0.0891
Epoch 11/15, Batch 800/938, Loss: 0.0845
Epoch 11/15, Batch 850/938, Loss: 0.0868
Epoch 11/15, Batch 900/938, Loss: 0.0889
Epoch 11/15 completed. Test Loss: 0.0800
Epoch 12/15, Batch 0/938, Loss: 0.0884
Epoch 12/15, Batch 50/938, Loss: 0.0879
Epoch 12/15, Batch 100/938, Loss: 0.0860
Epoch 12/15, Batch 150/938, Loss: 0.0785
Epoch 12/15, Batch 200/938, Loss: 0.0811
Epoch 12/15, Batch 250/938, Loss: 0.0837
Epoch 12/15, Batch 300/938, Loss: 0.0837
Epoch 12/15, Batch 350/938, Loss: 0.0862
Epoch 12/15, Batch 400/938, Loss: 0.0783
Epoch 12/15, Batch 450/938, Loss: 0.0838
Epoch 12/15, Batch 500/938, Loss: 0.0764
Epoch 12/15, Batch 550/938, Loss: 0.0862
Epoch 12/15, Batch 600/938, Loss: 0.0848
Epoch 12/15, Batch 650/938, Loss: 0.0791
Epoch 12/15, Batch 700/938, Loss: 0.0868
Epoch 12/15, Batch 750/938, Loss: 0.0830
Epoch 12/15, Batch 800/938, Loss: 0.0882
Epoch 12/15, Batch 850/938, Loss: 0.0853
Epoch 12/15, Batch 900/938, Loss: 0.0819
Epoch 12/15 completed. Test Loss: 0.0795
Epoch 13/15, Batch 0/938, Loss: 0.0900
Epoch 13/15, Batch 50/938, Loss: 0.0815
Epoch 13/15, Batch 100/938, Loss: 0.0791
Epoch 13/15, Batch 150/938, Loss: 0.0890
Epoch 13/15, Batch 200/938, Loss: 0.0841
Epoch 13/15, Batch 250/938, Loss: 0.0842
Epoch 13/15, Batch 300/938, Loss: 0.0799
Epoch 13/15, Batch 350/938, Loss: 0.0808
Epoch 13/15, Batch 400/938, Loss: 0.0847
Epoch 13/15, Batch 450/938, Loss: 0.0811
Epoch 13/15, Batch 500/938, Loss: 0.0836
Epoch 13/15, Batch 550/938, Loss: 0.0857
Epoch 13/15, Batch 600/938, Loss: 0.0791
Epoch 13/15, Batch 650/938, Loss: 0.0851
Epoch 13/15, Batch 700/938, Loss: 0.0869
Epoch 13/15, Batch 750/938, Loss: 0.0825
Epoch 13/15, Batch 800/938, Loss: 0.0871
Epoch 13/15, Batch 850/938, Loss: 0.0839
Epoch 13/15, Batch 900/938, Loss: 0.0818
Epoch 13/15 completed. Test Loss: 0.0795
Epoch 14/15, Batch 0/938, Loss: 0.0849
Epoch 14/15, Batch 50/938, Loss: 0.0858
Epoch 14/15, Batch 100/938, Loss: 0.0865
Epoch 14/15, Batch 150/938, Loss: 0.0852
Epoch 14/15, Batch 200/938, Loss: 0.0767
Epoch 14/15, Batch 250/938, Loss: 0.0798
Epoch 14/15, Batch 300/938, Loss: 0.0855
Epoch 14/15, Batch 350/938, Loss: 0.0857
Epoch 14/15, Batch 400/938, Loss: 0.0797
Epoch 14/15, Batch 450/938, Loss: 0.0820
Epoch 14/15, Batch 500/938, Loss: 0.0840
Epoch 14/15, Batch 550/938, Loss: 0.0817
Epoch 14/15, Batch 600/938, Loss: 0.0858
Epoch 14/15, Batch 650/938, Loss: 0.0837
Epoch 14/15, Batch 700/938, Loss: 0.0860
Epoch 14/15, Batch 750/938, Loss: 0.0861
Epoch 14/15, Batch 800/938, Loss: 0.0853
Epoch 14/15, Batch 850/938, Loss: 0.0867
Epoch 14/15, Batch 900/938, Loss: 0.0881
Epoch 14/15 completed. Test Loss: 0.0794
Epoch 15/15, Batch 0/938, Loss: 0.0810
Epoch 15/15, Batch 50/938, Loss: 0.0833
Epoch 15/15, Batch 100/938, Loss: 0.0777
Epoch 15/15, Batch 150/938, Loss: 0.0828
Epoch 15/15, Batch 200/938, Loss: 0.0837
Epoch 15/15, Batch 250/938, Loss: 0.0802
Epoch 15/15, Batch 300/938, Loss: 0.0861
Epoch 15/15, Batch 350/938, Loss: 0.0891
Epoch 15/15, Batch 400/938, Loss: 0.0805
Epoch 15/15, Batch 450/938, Loss: 0.0828
Epoch 15/15, Batch 500/938, Loss: 0.0743
Epoch 15/15, Batch 550/938, Loss: 0.0821
Epoch 15/15, Batch 600/938, Loss: 0.0784
Epoch 15/15, Batch 650/938, Loss: 0.0816
Epoch 15/15, Batch 700/938, Loss: 0.0816
Epoch 15/15, Batch 750/938, Loss: 0.0760
Epoch 15/15, Batch 800/938, Loss: 0.0820
Epoch 15/15, Batch 850/938, Loss: 0.0833
Epoch 15/15, Batch 900/938, Loss: 0.0869
Epoch 15/15 completed. Test Loss: 0.0793
Final Test Loss: 0.0793
No description has been provided for this image
samples shape:  (100, 28, 28, 1)
No description has been provided for this image

Part (b) iGPT on Colored Shapes and MNIST¶

Now, implement an iGPT that models color. In order to reduce the length of token sequences, iGPT models each RGB pixel as a single token. This effectively reduces the context length from HWC to just H*W. iGPT does this through a k-means clustering approach. Because our images only each can only take on 4 values (2 bits) per channel, we can represent each pixel with 64 values (6 bits). Convert the dataset into an image of tokens and train iGPT on the colored shapes and MNIST dataset.

Checkout the iGPT paper for more details: Generative Pretraining from Pixels

Training times and hyperparameter settings should be the same as part (a), except train for longer (15 epochs)

You will provide these deliverables

  1. Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and test data (for your entire test set). Code is provided that automatically plots the training curves.
  2. Report the final test set performance of your final model
  3. 100 samples from the final trained model
In [13]:
def q3_b(train_data, test_data, image_shape, dset_id):
    """
    train_data: A (n_train, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
    test_data: A (n_test, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
    image_shape: (H, W, C), height, width, and # of channels of the image
    dset_id: An identifying number of which dataset is given (1 or 2). Most likely
            used to set different hyperparameters for different datasets

    Returns
    - a (# of training iterations,) numpy array of train_losses evaluated every minibatch
    - a (# of epochs + 1,) numpy array of test_losses evaluated once at initialization and after each epoch
    - a numpy array of size (100, H, W, C) of samples with values in {0, 1, 2, 3}
    """
    batch_size = 64
    learning_rate = 1e-3
    num_epochs = 15
    
    # Model parameters as recommended in the instructions
    d_model = 128
    n_heads = 4
    n_layers = 2
        
    H, W, C = image_shape
    sequence_length = H * W + 1  # +1 for <bos> token
    vocab_size = 64  # each pixel be represented by 6 bits
    
    # Create datasets and data loaders
    train_loader = create_dataset(train_data, image_shape, batch_size)
    test_loader = create_dataset(test_data, image_shape, batch_size)
    # Initialize model and move to device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = iGPT(vocab_size, sequence_length, d_model, n_heads, n_layers).to(device)
    
    # Train the model
    train_losses, test_losses = train_igpt(model, train_loader, test_loader, 
                                          sequence_length, vocab_size, device,
                                          num_epochs, learning_rate)
    
    # save the model
    torch.save(model, f'model_colored_no_cache_{dset_id}.pth')
    # Generate samples
    samples ,_ = generate_samples(model, sequence_length, vocab_size, image_shape, device)
    
    return train_losses, test_losses, samples

Results¶

Once you've implemented q3_b, execute the cells below to visualize and save your results

In [14]:
q3ab_save_results(1, 'b', q3_b)
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
Epoch 1/15, Batch 0/164, Loss: 4.2606
Epoch 1/15, Batch 50/164, Loss: 3.9352
Epoch 1/15, Batch 100/164, Loss: 2.9784
Epoch 1/15, Batch 150/164, Loss: 1.8770
Epoch 1/15 completed. Test Loss: 1.3999
Epoch 2/15, Batch 0/164, Loss: 1.5793
Epoch 2/15, Batch 50/164, Loss: 0.7684
Epoch 2/15, Batch 100/164, Loss: 0.4597
Epoch 2/15, Batch 150/164, Loss: 0.3514
Epoch 2/15 completed. Test Loss: 0.2962
Epoch 3/15, Batch 0/164, Loss: 0.3175
Epoch 3/15, Batch 50/164, Loss: 0.2741
Epoch 3/15, Batch 100/164, Loss: 0.2230
Epoch 3/15, Batch 150/164, Loss: 0.1922
Epoch 3/15 completed. Test Loss: 0.1610
Epoch 4/15, Batch 0/164, Loss: 0.1853
Epoch 4/15, Batch 50/164, Loss: 0.1634
Epoch 4/15, Batch 100/164, Loss: 0.1535
Epoch 4/15, Batch 150/164, Loss: 0.1427
Epoch 4/15 completed. Test Loss: 0.1260
Epoch 5/15, Batch 0/164, Loss: 0.1391
Epoch 5/15, Batch 50/164, Loss: 0.1338
Epoch 5/15, Batch 100/164, Loss: 0.1241
Epoch 5/15, Batch 150/164, Loss: 0.1180
Epoch 5/15 completed. Test Loss: 0.1067
Epoch 6/15, Batch 0/164, Loss: 0.1192
Epoch 6/15, Batch 50/164, Loss: 0.1212
Epoch 6/15, Batch 100/164, Loss: 0.1110
Epoch 6/15, Batch 150/164, Loss: 0.1097
Epoch 6/15 completed. Test Loss: 0.0975
Epoch 7/15, Batch 0/164, Loss: 0.1115
Epoch 7/15, Batch 50/164, Loss: 0.1079
Epoch 7/15, Batch 100/164, Loss: 0.1004
Epoch 7/15, Batch 150/164, Loss: 0.0946
Epoch 7/15 completed. Test Loss: 0.0933
Epoch 8/15, Batch 0/164, Loss: 0.1037
Epoch 8/15, Batch 50/164, Loss: 0.0980
Epoch 8/15, Batch 100/164, Loss: 0.0954
Epoch 8/15, Batch 150/164, Loss: 0.1046
Epoch 8/15 completed. Test Loss: 0.0880
Epoch 9/15, Batch 0/164, Loss: 0.1014
Epoch 9/15, Batch 50/164, Loss: 0.0952
Epoch 9/15, Batch 100/164, Loss: 0.1003
Epoch 9/15, Batch 150/164, Loss: 0.0909
Epoch 9/15 completed. Test Loss: 0.0850
Epoch 10/15, Batch 0/164, Loss: 0.0897
Epoch 10/15, Batch 50/164, Loss: 0.0937
Epoch 10/15, Batch 100/164, Loss: 0.0871
Epoch 10/15, Batch 150/164, Loss: 0.0910
Epoch 10/15 completed. Test Loss: 0.0824
Epoch 11/15, Batch 0/164, Loss: 0.0855
Epoch 11/15, Batch 50/164, Loss: 0.0893
Epoch 11/15, Batch 100/164, Loss: 0.0896
Epoch 11/15, Batch 150/164, Loss: 0.0850
Epoch 11/15 completed. Test Loss: 0.0815
Epoch 12/15, Batch 0/164, Loss: 0.0802
Epoch 12/15, Batch 50/164, Loss: 0.0910
Epoch 12/15, Batch 100/164, Loss: 0.0933
Epoch 12/15, Batch 150/164, Loss: 0.0842
Epoch 12/15 completed. Test Loss: 0.0794
Epoch 13/15, Batch 0/164, Loss: 0.0824
Epoch 13/15, Batch 50/164, Loss: 0.0888
Epoch 13/15, Batch 100/164, Loss: 0.0850
Epoch 13/15, Batch 150/164, Loss: 0.0849
Epoch 13/15 completed. Test Loss: 0.0783
Epoch 14/15, Batch 0/164, Loss: 0.0828
Epoch 14/15, Batch 50/164, Loss: 0.0879
Epoch 14/15, Batch 100/164, Loss: 0.0794
Epoch 14/15, Batch 150/164, Loss: 0.0887
Epoch 14/15 completed. Test Loss: 0.0779
Epoch 15/15, Batch 0/164, Loss: 0.0864
Epoch 15/15, Batch 50/164, Loss: 0.0795
Epoch 15/15, Batch 100/164, Loss: 0.0850
Epoch 15/15, Batch 150/164, Loss: 0.0844
Epoch 15/15 completed. Test Loss: 0.0778
Final Test Loss: 0.0778
No description has been provided for this image
samples shape:  (100, 20, 20, 3)
No description has been provided for this image
In [15]:
q3ab_save_results(2, 'b', q3_b)
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
Epoch 1/15, Batch 0/938, Loss: 4.1959
Epoch 1/15, Batch 50/938, Loss: 3.3272
Epoch 1/15, Batch 100/938, Loss: 1.6369
Epoch 1/15, Batch 150/938, Loss: 1.1212
Epoch 1/15, Batch 200/938, Loss: 0.9332
Epoch 1/15, Batch 250/938, Loss: 0.8661
Epoch 1/15, Batch 300/938, Loss: 0.8122
Epoch 1/15, Batch 350/938, Loss: 0.8091
Epoch 1/15, Batch 400/938, Loss: 0.7632
Epoch 1/15, Batch 450/938, Loss: 0.7805
Epoch 1/15, Batch 500/938, Loss: 0.7259
Epoch 1/15, Batch 550/938, Loss: 0.6978
Epoch 1/15, Batch 600/938, Loss: 0.6527
Epoch 1/15, Batch 650/938, Loss: 0.6754
Epoch 1/15, Batch 700/938, Loss: 0.6195
Epoch 1/15, Batch 750/938, Loss: 0.6289
Epoch 1/15, Batch 800/938, Loss: 0.6234
Epoch 1/15, Batch 850/938, Loss: 0.6313
Epoch 1/15, Batch 900/938, Loss: 0.6226
Epoch 1/15 completed. Test Loss: 0.5847
Epoch 2/15, Batch 0/938, Loss: 0.5990
Epoch 2/15, Batch 50/938, Loss: 0.6081
Epoch 2/15, Batch 100/938, Loss: 0.6120
Epoch 2/15, Batch 150/938, Loss: 0.6272
Epoch 2/15, Batch 200/938, Loss: 0.5617
Epoch 2/15, Batch 250/938, Loss: 0.5834
Epoch 2/15, Batch 300/938, Loss: 0.5733
Epoch 2/15, Batch 350/938, Loss: 0.5822
Epoch 2/15, Batch 400/938, Loss: 0.5571
Epoch 2/15, Batch 450/938, Loss: 0.5379
Epoch 2/15, Batch 500/938, Loss: 0.5412
Epoch 2/15, Batch 550/938, Loss: 0.5793
Epoch 2/15, Batch 600/938, Loss: 0.5047
Epoch 2/15, Batch 650/938, Loss: 0.5169
Epoch 2/15, Batch 700/938, Loss: 0.5281
Epoch 2/15, Batch 750/938, Loss: 0.5090
Epoch 2/15, Batch 800/938, Loss: 0.5111
Epoch 2/15, Batch 850/938, Loss: 0.5154
Epoch 2/15, Batch 900/938, Loss: 0.4967
Epoch 2/15 completed. Test Loss: 0.4656
Epoch 3/15, Batch 0/938, Loss: 0.4917
Epoch 3/15, Batch 50/938, Loss: 0.5034
Epoch 3/15, Batch 100/938, Loss: 0.4944
Epoch 3/15, Batch 150/938, Loss: 0.4807
Epoch 3/15, Batch 200/938, Loss: 0.4832
Epoch 3/15, Batch 250/938, Loss: 0.4999
Epoch 3/15, Batch 300/938, Loss: 0.4862
Epoch 3/15, Batch 350/938, Loss: 0.4662
Epoch 3/15, Batch 400/938, Loss: 0.4932
Epoch 3/15, Batch 450/938, Loss: 0.4692
Epoch 3/15, Batch 500/938, Loss: 0.4872
Epoch 3/15, Batch 550/938, Loss: 0.4541
Epoch 3/15, Batch 600/938, Loss: 0.4566
Epoch 3/15, Batch 650/938, Loss: 0.4587
Epoch 3/15, Batch 700/938, Loss: 0.4554
Epoch 3/15, Batch 750/938, Loss: 0.4549
Epoch 3/15, Batch 800/938, Loss: 0.4302
Epoch 3/15, Batch 850/938, Loss: 0.4279
Epoch 3/15, Batch 900/938, Loss: 0.4309
Epoch 3/15 completed. Test Loss: 0.4147
Epoch 4/15, Batch 0/938, Loss: 0.4480
Epoch 4/15, Batch 50/938, Loss: 0.4450
Epoch 4/15, Batch 100/938, Loss: 0.4297
Epoch 4/15, Batch 150/938, Loss: 0.4473
Epoch 4/15, Batch 200/938, Loss: 0.4273
Epoch 4/15, Batch 250/938, Loss: 0.4424
Epoch 4/15, Batch 300/938, Loss: 0.4214
Epoch 4/15, Batch 350/938, Loss: 0.4466
Epoch 4/15, Batch 400/938, Loss: 0.4289
Epoch 4/15, Batch 450/938, Loss: 0.4011
Epoch 4/15, Batch 500/938, Loss: 0.4159
Epoch 4/15, Batch 550/938, Loss: 0.4178
Epoch 4/15, Batch 600/938, Loss: 0.4113
Epoch 4/15, Batch 650/938, Loss: 0.3959
Epoch 4/15, Batch 700/938, Loss: 0.4184
Epoch 4/15, Batch 750/938, Loss: 0.4069
Epoch 4/15, Batch 800/938, Loss: 0.4161
Epoch 4/15, Batch 850/938, Loss: 0.4211
Epoch 4/15, Batch 900/938, Loss: 0.3780
Epoch 4/15 completed. Test Loss: 0.3693
Epoch 5/15, Batch 0/938, Loss: 0.4028
Epoch 5/15, Batch 50/938, Loss: 0.4012
Epoch 5/15, Batch 100/938, Loss: 0.3897
Epoch 5/15, Batch 150/938, Loss: 0.3895
Epoch 5/15, Batch 200/938, Loss: 0.4058
Epoch 5/15, Batch 250/938, Loss: 0.3899
Epoch 5/15, Batch 300/938, Loss: 0.3730
Epoch 5/15, Batch 350/938, Loss: 0.3881
Epoch 5/15, Batch 400/938, Loss: 0.3705
Epoch 5/15, Batch 450/938, Loss: 0.3675
Epoch 5/15, Batch 500/938, Loss: 0.3640
Epoch 5/15, Batch 550/938, Loss: 0.3794
Epoch 5/15, Batch 600/938, Loss: 0.3628
Epoch 5/15, Batch 650/938, Loss: 0.3812
Epoch 5/15, Batch 700/938, Loss: 0.3693
Epoch 5/15, Batch 750/938, Loss: 0.3612
Epoch 5/15, Batch 800/938, Loss: 0.3775
Epoch 5/15, Batch 850/938, Loss: 0.3655
Epoch 5/15, Batch 900/938, Loss: 0.3602
Epoch 5/15 completed. Test Loss: 0.3128
Epoch 6/15, Batch 0/938, Loss: 0.3612
Epoch 6/15, Batch 50/938, Loss: 0.3532
Epoch 6/15, Batch 100/938, Loss: 0.3500
Epoch 6/15, Batch 150/938, Loss: 0.3439
Epoch 6/15, Batch 200/938, Loss: 0.3574
Epoch 6/15, Batch 250/938, Loss: 0.3536
Epoch 6/15, Batch 300/938, Loss: 0.3484
Epoch 6/15, Batch 350/938, Loss: 0.3546
Epoch 6/15, Batch 400/938, Loss: 0.3418
Epoch 6/15, Batch 450/938, Loss: 0.3334
Epoch 6/15, Batch 500/938, Loss: 0.3356
Epoch 6/15, Batch 550/938, Loss: 0.3387
Epoch 6/15, Batch 600/938, Loss: 0.3421
Epoch 6/15, Batch 650/938, Loss: 0.3313
Epoch 6/15, Batch 700/938, Loss: 0.3208
Epoch 6/15, Batch 750/938, Loss: 0.3231
Epoch 6/15, Batch 800/938, Loss: 0.3225
Epoch 6/15, Batch 850/938, Loss: 0.3241
Epoch 6/15, Batch 900/938, Loss: 0.3156
Epoch 6/15 completed. Test Loss: 0.2667
Epoch 7/15, Batch 0/938, Loss: 0.3260
Epoch 7/15, Batch 50/938, Loss: 0.3193
Epoch 7/15, Batch 100/938, Loss: 0.3169
Epoch 7/15, Batch 150/938, Loss: 0.3166
Epoch 7/15, Batch 200/938, Loss: 0.3136
Epoch 7/15, Batch 250/938, Loss: 0.2925
Epoch 7/15, Batch 300/938, Loss: 0.3076
Epoch 7/15, Batch 350/938, Loss: 0.2909
Epoch 7/15, Batch 400/938, Loss: 0.2994
Epoch 7/15, Batch 450/938, Loss: 0.2952
Epoch 7/15, Batch 500/938, Loss: 0.3022
Epoch 7/15, Batch 550/938, Loss: 0.3047
Epoch 7/15, Batch 600/938, Loss: 0.3051
Epoch 7/15, Batch 650/938, Loss: 0.2924
Epoch 7/15, Batch 700/938, Loss: 0.2902
Epoch 7/15, Batch 750/938, Loss: 0.3053
Epoch 7/15, Batch 800/938, Loss: 0.2979
Epoch 7/15, Batch 850/938, Loss: 0.2890
Epoch 7/15, Batch 900/938, Loss: 0.2833
Epoch 7/15 completed. Test Loss: 0.2301
Epoch 8/15, Batch 0/938, Loss: 0.2992
Epoch 8/15, Batch 50/938, Loss: 0.2773
Epoch 8/15, Batch 100/938, Loss: 0.2916
Epoch 8/15, Batch 150/938, Loss: 0.2849
Epoch 8/15, Batch 200/938, Loss: 0.2839
Epoch 8/15, Batch 250/938, Loss: 0.2995
Epoch 8/15, Batch 300/938, Loss: 0.2734
Epoch 8/15, Batch 350/938, Loss: 0.2864
Epoch 8/15, Batch 400/938, Loss: 0.2846
Epoch 8/15, Batch 450/938, Loss: 0.2717
Epoch 8/15, Batch 500/938, Loss: 0.2778
Epoch 8/15, Batch 550/938, Loss: 0.2761
Epoch 8/15, Batch 600/938, Loss: 0.2738
Epoch 8/15, Batch 650/938, Loss: 0.2879
Epoch 8/15, Batch 700/938, Loss: 0.2704
Epoch 8/15, Batch 750/938, Loss: 0.2780
Epoch 8/15, Batch 800/938, Loss: 0.2679
Epoch 8/15, Batch 850/938, Loss: 0.2630
Epoch 8/15, Batch 900/938, Loss: 0.2613
Epoch 8/15 completed. Test Loss: 0.2041
Epoch 9/15, Batch 0/938, Loss: 0.2680
Epoch 9/15, Batch 50/938, Loss: 0.2633
Epoch 9/15, Batch 100/938, Loss: 0.2661
Epoch 9/15, Batch 150/938, Loss: 0.2719
Epoch 9/15, Batch 200/938, Loss: 0.2641
Epoch 9/15, Batch 250/938, Loss: 0.2584
Epoch 9/15, Batch 300/938, Loss: 0.2552
Epoch 9/15, Batch 350/938, Loss: 0.2596
Epoch 9/15, Batch 400/938, Loss: 0.2643
Epoch 9/15, Batch 450/938, Loss: 0.2499
Epoch 9/15, Batch 500/938, Loss: 0.2623
Epoch 9/15, Batch 550/938, Loss: 0.2594
Epoch 9/15, Batch 600/938, Loss: 0.2533
Epoch 9/15, Batch 650/938, Loss: 0.2577
Epoch 9/15, Batch 700/938, Loss: 0.2555
Epoch 9/15, Batch 750/938, Loss: 0.2440
Epoch 9/15, Batch 800/938, Loss: 0.2421
Epoch 9/15, Batch 850/938, Loss: 0.2551
Epoch 9/15, Batch 900/938, Loss: 0.2501
Epoch 9/15 completed. Test Loss: 0.1880
Epoch 10/15, Batch 0/938, Loss: 0.2460
Epoch 10/15, Batch 50/938, Loss: 0.2493
Epoch 10/15, Batch 100/938, Loss: 0.2540
Epoch 10/15, Batch 150/938, Loss: 0.2462
Epoch 10/15, Batch 200/938, Loss: 0.2484
Epoch 10/15, Batch 250/938, Loss: 0.2506
Epoch 10/15, Batch 300/938, Loss: 0.2444
Epoch 10/15, Batch 350/938, Loss: 0.2459
Epoch 10/15, Batch 400/938, Loss: 0.2455
Epoch 10/15, Batch 450/938, Loss: 0.2488
Epoch 10/15, Batch 500/938, Loss: 0.2387
Epoch 10/15, Batch 550/938, Loss: 0.2427
Epoch 10/15, Batch 600/938, Loss: 0.2396
Epoch 10/15, Batch 650/938, Loss: 0.2390
Epoch 10/15, Batch 700/938, Loss: 0.2406
Epoch 10/15, Batch 750/938, Loss: 0.2400
Epoch 10/15, Batch 800/938, Loss: 0.2396
Epoch 10/15, Batch 850/938, Loss: 0.2486
Epoch 10/15, Batch 900/938, Loss: 0.2500
Epoch 10/15 completed. Test Loss: 0.1780
Epoch 11/15, Batch 0/938, Loss: 0.2323
Epoch 11/15, Batch 50/938, Loss: 0.2357
Epoch 11/15, Batch 100/938, Loss: 0.2415
Epoch 11/15, Batch 150/938, Loss: 0.2311
Epoch 11/15, Batch 200/938, Loss: 0.2305
Epoch 11/15, Batch 250/938, Loss: 0.2324
Epoch 11/15, Batch 300/938, Loss: 0.2336
Epoch 11/15, Batch 350/938, Loss: 0.2358
Epoch 11/15, Batch 400/938, Loss: 0.2434
Epoch 11/15, Batch 450/938, Loss: 0.2302
Epoch 11/15, Batch 500/938, Loss: 0.2353
Epoch 11/15, Batch 550/938, Loss: 0.2304
Epoch 11/15, Batch 600/938, Loss: 0.2422
Epoch 11/15, Batch 650/938, Loss: 0.2369
Epoch 11/15, Batch 700/938, Loss: 0.2270
Epoch 11/15, Batch 750/938, Loss: 0.2330
Epoch 11/15, Batch 800/938, Loss: 0.2259
Epoch 11/15, Batch 850/938, Loss: 0.2353
Epoch 11/15, Batch 900/938, Loss: 0.2394
Epoch 11/15 completed. Test Loss: 0.1705
Epoch 12/15, Batch 0/938, Loss: 0.2365
Epoch 12/15, Batch 50/938, Loss: 0.2298
Epoch 12/15, Batch 100/938, Loss: 0.2311
Epoch 12/15, Batch 150/938, Loss: 0.2342
Epoch 12/15, Batch 200/938, Loss: 0.2210
Epoch 12/15, Batch 250/938, Loss: 0.2292
Epoch 12/15, Batch 300/938, Loss: 0.2256
Epoch 12/15, Batch 350/938, Loss: 0.2278
Epoch 12/15, Batch 400/938, Loss: 0.2214
Epoch 12/15, Batch 450/938, Loss: 0.2271
Epoch 12/15, Batch 500/938, Loss: 0.2241
Epoch 12/15, Batch 550/938, Loss: 0.2269
Epoch 12/15, Batch 600/938, Loss: 0.2293
Epoch 12/15, Batch 650/938, Loss: 0.2342
Epoch 12/15, Batch 700/938, Loss: 0.2326
Epoch 12/15, Batch 750/938, Loss: 0.2253
Epoch 12/15, Batch 800/938, Loss: 0.2288
Epoch 12/15, Batch 850/938, Loss: 0.2196
Epoch 12/15, Batch 900/938, Loss: 0.2273
Epoch 12/15 completed. Test Loss: 0.1653
Epoch 13/15, Batch 0/938, Loss: 0.2240
Epoch 13/15, Batch 50/938, Loss: 0.2260
Epoch 13/15, Batch 100/938, Loss: 0.2244
Epoch 13/15, Batch 150/938, Loss: 0.2247
Epoch 13/15, Batch 200/938, Loss: 0.2240
Epoch 13/15, Batch 250/938, Loss: 0.2233
Epoch 13/15, Batch 300/938, Loss: 0.2121
Epoch 13/15, Batch 350/938, Loss: 0.2210
Epoch 13/15, Batch 400/938, Loss: 0.2178
Epoch 13/15, Batch 450/938, Loss: 0.2263
Epoch 13/15, Batch 500/938, Loss: 0.2255
Epoch 13/15, Batch 550/938, Loss: 0.2269
Epoch 13/15, Batch 600/938, Loss: 0.2210
Epoch 13/15, Batch 650/938, Loss: 0.2189
Epoch 13/15, Batch 700/938, Loss: 0.2343
Epoch 13/15, Batch 750/938, Loss: 0.2299
Epoch 13/15, Batch 800/938, Loss: 0.2213
Epoch 13/15, Batch 850/938, Loss: 0.2170
Epoch 13/15, Batch 900/938, Loss: 0.2248
Epoch 13/15 completed. Test Loss: 0.1622
Epoch 14/15, Batch 0/938, Loss: 0.2273
Epoch 14/15, Batch 50/938, Loss: 0.2216
Epoch 14/15, Batch 100/938, Loss: 0.2272
Epoch 14/15, Batch 150/938, Loss: 0.2236
Epoch 14/15, Batch 200/938, Loss: 0.2252
Epoch 14/15, Batch 250/938, Loss: 0.2224
Epoch 14/15, Batch 300/938, Loss: 0.2252
Epoch 14/15, Batch 350/938, Loss: 0.2267
Epoch 14/15, Batch 400/938, Loss: 0.2265
Epoch 14/15, Batch 450/938, Loss: 0.2222
Epoch 14/15, Batch 500/938, Loss: 0.2239
Epoch 14/15, Batch 550/938, Loss: 0.2202
Epoch 14/15, Batch 600/938, Loss: 0.2263
Epoch 14/15, Batch 650/938, Loss: 0.2195
Epoch 14/15, Batch 700/938, Loss: 0.2183
Epoch 14/15, Batch 750/938, Loss: 0.2260
Epoch 14/15, Batch 800/938, Loss: 0.2205
Epoch 14/15, Batch 850/938, Loss: 0.2257
Epoch 14/15, Batch 900/938, Loss: 0.2245
Epoch 14/15 completed. Test Loss: 0.1611
Epoch 15/15, Batch 0/938, Loss: 0.2189
Epoch 15/15, Batch 50/938, Loss: 0.2202
Epoch 15/15, Batch 100/938, Loss: 0.2225
Epoch 15/15, Batch 150/938, Loss: 0.2225
Epoch 15/15, Batch 200/938, Loss: 0.2239
Epoch 15/15, Batch 250/938, Loss: 0.2229
Epoch 15/15, Batch 300/938, Loss: 0.2234
Epoch 15/15, Batch 350/938, Loss: 0.2252
Epoch 15/15, Batch 400/938, Loss: 0.2178
Epoch 15/15, Batch 450/938, Loss: 0.2194
Epoch 15/15, Batch 500/938, Loss: 0.2283
Epoch 15/15, Batch 550/938, Loss: 0.2179
Epoch 15/15, Batch 600/938, Loss: 0.2217
Epoch 15/15, Batch 650/938, Loss: 0.2178
Epoch 15/15, Batch 700/938, Loss: 0.2236
Epoch 15/15, Batch 750/938, Loss: 0.2199
Epoch 15/15, Batch 800/938, Loss: 0.2239
Epoch 15/15, Batch 850/938, Loss: 0.2285
Epoch 15/15, Batch 900/938, Loss: 0.2182
Epoch 15/15 completed. Test Loss: 0.1610
Final Test Loss: 0.1610
No description has been provided for this image
samples shape:  (100, 28, 28, 3)
No description has been provided for this image

Part (c) K, V Caching for Improved Inference¶

You may have noticed that generation from the transformer is quite slow. Part of this is just due to the autoregressive nature. However, another part is due to some computational inefficiency. At each forward pass of the model, we are performing repeat computation of the past sequence. Specifically, we can cache the key and values at the multi attention layer to more quickly predict at each step.

In self-attention, a sequence is processed by generating three vectors for each element in the sequence: a Query (Q), a Key (K), and a Value (V). These vectors are then used to compute attention scores and subsequently the output of the attention layer. Mathematically, this can be represented as:

  • For each index $i$, compute $Q_i$, $K_i$, $V_i$ for the current element
  • Retrieve $K_{<i}$ and $V_{<i}$ from the cache (where $<i$ denotes all indices before the current one)
  • Compute the attention output using $Q_i$, $[K_{<i}, K_i]$, $[V_{<i}, V_i]$

Next implement caching for your transformer to make inference more efficient by modifying your self attention. Use caching for inference in the future problems for faster generation! (Note caching is only used during inference). You will use the same dataset as in part B, dataset 2 of this question (colored mnist). No training is required in this section, feel free to reuse the model you trained in part B, dataset 2.

You will provide these deliverables

  1. Over the course of inference, measure the time for the forward pass over the total sequence length with and without caching.
  2. 100 samples from the final trained model using the caching inference pipeline.
In [20]:
def q3_c(train_data, test_data, image_shape, dset_id):
    import os
    """
    train_data: A (n_train, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
    test_data: A (n_test, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
    image_shape: (H, W, C), height, width, and # of channels of the image
    dset_id: An identifying number of which dataset is given (1 or 2). Most likely
             used to set different hyperparameters for different datasets

    Returns
    - a (# sampling steps,) numpy array of time per sampling iteration, without caching
    - a (# sampling steps,) numpy array of time per sampling iteration, with caching
    - a numpy array of size (100, H, W, C) of samples with values in {0, 1, 2, 3} (sample generated without caching)
    - a numpy array of size (100, H, W, C) of samples with values in {0, 1, 2, 3} (sample generated with caching)
    """
    # Model hyperparameters
    batch_size = 64
    learning_rate = 1e-3
    num_epochs = 15
    
    # Transformer architecture parameters
    d_model = 128
    n_heads = 4
    n_layers = 2
        
    H, W, C = image_shape
    print("image shape: ", image_shape)
    sequence_length = H * W + 1  # +1 for <bos> token
    vocab_size = 64  # each pixel represented by 6 bits
    
    # Create datasets and data loaders
    train_loader = create_dataset(train_data, image_shape, batch_size)
    test_loader = create_dataset(test_data, image_shape, batch_size)
    
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load or train the model
    model = iGPT(vocab_size, sequence_length, d_model, n_heads, n_layers).to(device)
    train_losses, test_losses = train_igpt(model, train_loader, test_loader, 
                                            sequence_length, vocab_size, device,
                                            num_epochs, learning_rate)
 

    
    # Generate samples without caching and measure time
    # start_time = time.time()
    samples_no_cache, time_list_no_cache = generate_samples(model, sequence_length, vocab_size, image_shape, device, use_cache=False, test_mode=False)
    
    # Generate samples with caching and measure time
    samples_with_cache, time_list_with_cache = generate_samples(model, sequence_length, vocab_size, image_shape, device, use_cache=True, test_mode=False)
    # print(f"Speedup: {total_time_no_cache / total_time_with_cache:.2f}x")
    

    return time_list_no_cache, time_list_with_cache, samples_no_cache, samples_with_cache

Results¶

Once you've implemented q3_c, execute the cells below to visualize and save your results

In [21]:
q3c_save_results(2, q3_c)
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
image shape:  (28, 28, 3)
Epoch 1/15, Batch 0/938, Loss: 4.2718
Epoch 1/15, Batch 50/938, Loss: 3.5423
Epoch 1/15, Batch 100/938, Loss: 1.7781
Epoch 1/15, Batch 150/938, Loss: 1.1359
Epoch 1/15, Batch 200/938, Loss: 0.9425
Epoch 1/15, Batch 250/938, Loss: 0.8162
Epoch 1/15, Batch 300/938, Loss: 0.8538
Epoch 1/15, Batch 350/938, Loss: 0.7669
Epoch 1/15, Batch 400/938, Loss: 0.7949
Epoch 1/15, Batch 450/938, Loss: 0.7809
Epoch 1/15, Batch 500/938, Loss: 0.7572
Epoch 1/15, Batch 550/938, Loss: 0.7087
Epoch 1/15, Batch 600/938, Loss: 0.6779
Epoch 1/15, Batch 650/938, Loss: 0.6718
Epoch 1/15, Batch 700/938, Loss: 0.6426
Epoch 1/15, Batch 750/938, Loss: 0.6188
Epoch 1/15, Batch 800/938, Loss: 0.6319
Epoch 1/15, Batch 850/938, Loss: 0.6013
Epoch 1/15, Batch 900/938, Loss: 0.6063
Epoch 1/15 completed. Test Loss: 0.5818
Epoch 2/15, Batch 0/938, Loss: 0.5846
Epoch 2/15, Batch 50/938, Loss: 0.5855
Epoch 2/15, Batch 100/938, Loss: 0.5834
Epoch 2/15, Batch 150/938, Loss: 0.5655
Epoch 2/15, Batch 200/938, Loss: 0.5744
Epoch 2/15, Batch 250/938, Loss: 0.5381
Epoch 2/15, Batch 300/938, Loss: 0.5357
Epoch 2/15, Batch 350/938, Loss: 0.5642
Epoch 2/15, Batch 400/938, Loss: 0.5569
Epoch 2/15, Batch 450/938, Loss: 0.5557
Epoch 2/15, Batch 500/938, Loss: 0.5126
Epoch 2/15, Batch 550/938, Loss: 0.5367
Epoch 2/15, Batch 600/938, Loss: 0.5142
Epoch 2/15, Batch 650/938, Loss: 0.5057
Epoch 2/15, Batch 700/938, Loss: 0.4898
Epoch 2/15, Batch 750/938, Loss: 0.4923
Epoch 2/15, Batch 800/938, Loss: 0.5167
Epoch 2/15, Batch 850/938, Loss: 0.4860
Epoch 2/15, Batch 900/938, Loss: 0.4951
Epoch 2/15 completed. Test Loss: 0.4632
Epoch 3/15, Batch 0/938, Loss: 0.4827
Epoch 3/15, Batch 50/938, Loss: 0.4902
Epoch 3/15, Batch 100/938, Loss: 0.4752
Epoch 3/15, Batch 150/938, Loss: 0.4781
Epoch 3/15, Batch 200/938, Loss: 0.4612
Epoch 3/15, Batch 250/938, Loss: 0.4779
Epoch 3/15, Batch 300/938, Loss: 0.4838
Epoch 3/15, Batch 350/938, Loss: 0.4754
Epoch 3/15, Batch 400/938, Loss: 0.4845
Epoch 3/15, Batch 450/938, Loss: 0.4675
Epoch 3/15, Batch 500/938, Loss: 0.4605
Epoch 3/15, Batch 550/938, Loss: 0.4582
Epoch 3/15, Batch 600/938, Loss: 0.4550
Epoch 3/15, Batch 650/938, Loss: 0.4428
Epoch 3/15, Batch 700/938, Loss: 0.4609
Epoch 3/15, Batch 750/938, Loss: 0.4617
Epoch 3/15, Batch 800/938, Loss: 0.4544
Epoch 3/15, Batch 850/938, Loss: 0.4646
Epoch 3/15, Batch 900/938, Loss: 0.4584
Epoch 3/15 completed. Test Loss: 0.4146
Epoch 4/15, Batch 0/938, Loss: 0.4440
Epoch 4/15, Batch 50/938, Loss: 0.4569
Epoch 4/15, Batch 100/938, Loss: 0.4488
Epoch 4/15, Batch 150/938, Loss: 0.4396
Epoch 4/15, Batch 200/938, Loss: 0.4434
Epoch 4/15, Batch 250/938, Loss: 0.4332
Epoch 4/15, Batch 300/938, Loss: 0.4298
Epoch 4/15, Batch 350/938, Loss: 0.4441
Epoch 4/15, Batch 400/938, Loss: 0.4122
Epoch 4/15, Batch 450/938, Loss: 0.4090
Epoch 4/15, Batch 500/938, Loss: 0.4317
Epoch 4/15, Batch 550/938, Loss: 0.4174
Epoch 4/15, Batch 600/938, Loss: 0.4154
Epoch 4/15, Batch 650/938, Loss: 0.4222
Epoch 4/15, Batch 700/938, Loss: 0.4183
Epoch 4/15, Batch 750/938, Loss: 0.4125
Epoch 4/15, Batch 800/938, Loss: 0.4120
Epoch 4/15, Batch 850/938, Loss: 0.4163
Epoch 4/15, Batch 900/938, Loss: 0.3967
Epoch 4/15 completed. Test Loss: 0.3685
Epoch 5/15, Batch 0/938, Loss: 0.4068
Epoch 5/15, Batch 50/938, Loss: 0.4104
Epoch 5/15, Batch 100/938, Loss: 0.3983
Epoch 5/15, Batch 150/938, Loss: 0.3979
Epoch 5/15, Batch 200/938, Loss: 0.4037
Epoch 5/15, Batch 250/938, Loss: 0.4143
Epoch 5/15, Batch 300/938, Loss: 0.3918
Epoch 5/15, Batch 350/938, Loss: 0.3931
Epoch 5/15, Batch 400/938, Loss: 0.3846
Epoch 5/15, Batch 450/938, Loss: 0.3951
Epoch 5/15, Batch 500/938, Loss: 0.3836
Epoch 5/15, Batch 550/938, Loss: 0.3932
Epoch 5/15, Batch 600/938, Loss: 0.3683
Epoch 5/15, Batch 650/938, Loss: 0.3887
Epoch 5/15, Batch 700/938, Loss: 0.3801
Epoch 5/15, Batch 750/938, Loss: 0.3589
Epoch 5/15, Batch 800/938, Loss: 0.3982
Epoch 5/15, Batch 850/938, Loss: 0.3848
Epoch 5/15, Batch 900/938, Loss: 0.3697
Epoch 5/15 completed. Test Loss: 0.3261
Epoch 6/15, Batch 0/938, Loss: 0.3788
Epoch 6/15, Batch 50/938, Loss: 0.3716
Epoch 6/15, Batch 100/938, Loss: 0.3850
Epoch 6/15, Batch 150/938, Loss: 0.3589
Epoch 6/15, Batch 200/938, Loss: 0.3540
Epoch 6/15, Batch 250/938, Loss: 0.3504
Epoch 6/15, Batch 300/938, Loss: 0.3534
Epoch 6/15, Batch 350/938, Loss: 0.3633
Epoch 6/15, Batch 400/938, Loss: 0.3642
Epoch 6/15, Batch 450/938, Loss: 0.3338
Epoch 6/15, Batch 500/938, Loss: 0.3560
Epoch 6/15, Batch 550/938, Loss: 0.3586
Epoch 6/15, Batch 600/938, Loss: 0.3443
Epoch 6/15, Batch 650/938, Loss: 0.3518
Epoch 6/15, Batch 700/938, Loss: 0.3498
Epoch 6/15, Batch 750/938, Loss: 0.3511
Epoch 6/15, Batch 800/938, Loss: 0.3532
Epoch 6/15, Batch 850/938, Loss: 0.3329
Epoch 6/15, Batch 900/938, Loss: 0.3504
Epoch 6/15 completed. Test Loss: 0.2919
Epoch 7/15, Batch 0/938, Loss: 0.3376
Epoch 7/15, Batch 50/938, Loss: 0.3346
Epoch 7/15, Batch 100/938, Loss: 0.3414
Epoch 7/15, Batch 150/938, Loss: 0.3408
Epoch 7/15, Batch 200/938, Loss: 0.3296
Epoch 7/15, Batch 250/938, Loss: 0.3435
Epoch 7/15, Batch 300/938, Loss: 0.3412
Epoch 7/15, Batch 350/938, Loss: 0.3284
Epoch 7/15, Batch 400/938, Loss: 0.3412
Epoch 7/15, Batch 450/938, Loss: 0.3272
Epoch 7/15, Batch 500/938, Loss: 0.3277
Epoch 7/15, Batch 550/938, Loss: 0.3270
Epoch 7/15, Batch 600/938, Loss: 0.3257
Epoch 7/15, Batch 650/938, Loss: 0.3213
Epoch 7/15, Batch 700/938, Loss: 0.3168
Epoch 7/15, Batch 750/938, Loss: 0.3126
Epoch 7/15, Batch 800/938, Loss: 0.3260
Epoch 7/15, Batch 850/938, Loss: 0.3258
Epoch 7/15, Batch 900/938, Loss: 0.3150
Epoch 7/15 completed. Test Loss: 0.2624
Epoch 8/15, Batch 0/938, Loss: 0.3114
Epoch 8/15, Batch 50/938, Loss: 0.3081
Epoch 8/15, Batch 100/938, Loss: 0.3103
Epoch 8/15, Batch 150/938, Loss: 0.2964
Epoch 8/15, Batch 200/938, Loss: 0.3144
Epoch 8/15, Batch 250/938, Loss: 0.3217
Epoch 8/15, Batch 300/938, Loss: 0.3127
Epoch 8/15, Batch 350/938, Loss: 0.2814
Epoch 8/15, Batch 400/938, Loss: 0.3045
Epoch 8/15, Batch 450/938, Loss: 0.3139
Epoch 8/15, Batch 500/938, Loss: 0.2926
Epoch 8/15, Batch 550/938, Loss: 0.3015
Epoch 8/15, Batch 600/938, Loss: 0.2943
Epoch 8/15, Batch 650/938, Loss: 0.3153
Epoch 8/15, Batch 700/938, Loss: 0.2873
Epoch 8/15, Batch 750/938, Loss: 0.2926
Epoch 8/15, Batch 800/938, Loss: 0.2910
Epoch 8/15, Batch 850/938, Loss: 0.3043
Epoch 8/15, Batch 900/938, Loss: 0.3068
Epoch 8/15 completed. Test Loss: 0.2360
Epoch 9/15, Batch 0/938, Loss: 0.3018
Epoch 9/15, Batch 50/938, Loss: 0.3065
Epoch 9/15, Batch 100/938, Loss: 0.2952
Epoch 9/15, Batch 150/938, Loss: 0.2899
Epoch 9/15, Batch 200/938, Loss: 0.2881
Epoch 9/15, Batch 250/938, Loss: 0.2879
Epoch 9/15, Batch 300/938, Loss: 0.3020
Epoch 9/15, Batch 350/938, Loss: 0.2875
Epoch 9/15, Batch 400/938, Loss: 0.2785
Epoch 9/15, Batch 450/938, Loss: 0.2853
Epoch 9/15, Batch 500/938, Loss: 0.2670
Epoch 9/15, Batch 550/938, Loss: 0.2865
Epoch 9/15, Batch 600/938, Loss: 0.2784
Epoch 9/15, Batch 650/938, Loss: 0.2787
Epoch 9/15, Batch 700/938, Loss: 0.2908
Epoch 9/15, Batch 750/938, Loss: 0.2792
Epoch 9/15, Batch 800/938, Loss: 0.2856
Epoch 9/15, Batch 850/938, Loss: 0.2766
Epoch 9/15, Batch 900/938, Loss: 0.2750
Epoch 9/15 completed. Test Loss: 0.2178
Epoch 10/15, Batch 0/938, Loss: 0.2697
Epoch 10/15, Batch 50/938, Loss: 0.2798
Epoch 10/15, Batch 100/938, Loss: 0.2720
Epoch 10/15, Batch 150/938, Loss: 0.2746
Epoch 10/15, Batch 200/938, Loss: 0.2831
Epoch 10/15, Batch 250/938, Loss: 0.2805
Epoch 10/15, Batch 300/938, Loss: 0.2684
Epoch 10/15, Batch 350/938, Loss: 0.2604
Epoch 10/15, Batch 400/938, Loss: 0.2655
Epoch 10/15, Batch 450/938, Loss: 0.2691
Epoch 10/15, Batch 500/938, Loss: 0.2618
Epoch 10/15, Batch 550/938, Loss: 0.2704
Epoch 10/15, Batch 600/938, Loss: 0.2702
Epoch 10/15, Batch 650/938, Loss: 0.2603
Epoch 10/15, Batch 700/938, Loss: 0.2695
Epoch 10/15, Batch 750/938, Loss: 0.2712
Epoch 10/15, Batch 800/938, Loss: 0.2698
Epoch 10/15, Batch 850/938, Loss: 0.2727
Epoch 10/15, Batch 900/938, Loss: 0.2674
Epoch 10/15 completed. Test Loss: 0.2039
Epoch 11/15, Batch 0/938, Loss: 0.2671
Epoch 11/15, Batch 50/938, Loss: 0.2717
Epoch 11/15, Batch 100/938, Loss: 0.2713
Epoch 11/15, Batch 150/938, Loss: 0.2608
Epoch 11/15, Batch 200/938, Loss: 0.2629
Epoch 11/15, Batch 250/938, Loss: 0.2695
Epoch 11/15, Batch 300/938, Loss: 0.2711
Epoch 11/15, Batch 350/938, Loss: 0.2676
Epoch 11/15, Batch 400/938, Loss: 0.2611
Epoch 11/15, Batch 450/938, Loss: 0.2552
Epoch 11/15, Batch 500/938, Loss: 0.2589
Epoch 11/15, Batch 550/938, Loss: 0.2625
Epoch 11/15, Batch 600/938, Loss: 0.2622
Epoch 11/15, Batch 650/938, Loss: 0.2644
Epoch 11/15, Batch 700/938, Loss: 0.2656
Epoch 11/15, Batch 750/938, Loss: 0.2547
Epoch 11/15, Batch 800/938, Loss: 0.2615
Epoch 11/15, Batch 850/938, Loss: 0.2606
Epoch 11/15, Batch 900/938, Loss: 0.2508
Epoch 11/15 completed. Test Loss: 0.1964
Epoch 12/15, Batch 0/938, Loss: 0.2681
Epoch 12/15, Batch 50/938, Loss: 0.2654
Epoch 12/15, Batch 100/938, Loss: 0.2444
Epoch 12/15, Batch 150/938, Loss: 0.2510
Epoch 12/15, Batch 200/938, Loss: 0.2604
Epoch 12/15, Batch 250/938, Loss: 0.2544
Epoch 12/15, Batch 300/938, Loss: 0.2586
Epoch 12/15, Batch 350/938, Loss: 0.2608
Epoch 12/15, Batch 400/938, Loss: 0.2521
Epoch 12/15, Batch 450/938, Loss: 0.2511
Epoch 12/15, Batch 500/938, Loss: 0.2401
Epoch 12/15, Batch 550/938, Loss: 0.2576
Epoch 12/15, Batch 600/938, Loss: 0.2604
Epoch 12/15, Batch 650/938, Loss: 0.2551
Epoch 12/15, Batch 700/938, Loss: 0.2525
Epoch 12/15, Batch 750/938, Loss: 0.2486
Epoch 12/15, Batch 800/938, Loss: 0.2555
Epoch 12/15, Batch 850/938, Loss: 0.2551
Epoch 12/15, Batch 900/938, Loss: 0.2485
Epoch 12/15 completed. Test Loss: 0.1897
Epoch 13/15, Batch 0/938, Loss: 0.2502
Epoch 13/15, Batch 50/938, Loss: 0.2573
Epoch 13/15, Batch 100/938, Loss: 0.2535
Epoch 13/15, Batch 150/938, Loss: 0.2527
Epoch 13/15, Batch 200/938, Loss: 0.2568
Epoch 13/15, Batch 250/938, Loss: 0.2567
Epoch 13/15, Batch 300/938, Loss: 0.2436
Epoch 13/15, Batch 350/938, Loss: 0.2546
Epoch 13/15, Batch 400/938, Loss: 0.2506
Epoch 13/15, Batch 450/938, Loss: 0.2578
Epoch 13/15, Batch 500/938, Loss: 0.2490
Epoch 13/15, Batch 550/938, Loss: 0.2463
Epoch 13/15, Batch 600/938, Loss: 0.2582
Epoch 13/15, Batch 650/938, Loss: 0.2472
Epoch 13/15, Batch 700/938, Loss: 0.2518
Epoch 13/15, Batch 750/938, Loss: 0.2475
Epoch 13/15, Batch 800/938, Loss: 0.2477
Epoch 13/15, Batch 850/938, Loss: 0.2563
Epoch 13/15, Batch 900/938, Loss: 0.2399
Epoch 13/15 completed. Test Loss: 0.1862
Epoch 14/15, Batch 0/938, Loss: 0.2519
Epoch 14/15, Batch 50/938, Loss: 0.2650
Epoch 14/15, Batch 100/938, Loss: 0.2491
Epoch 14/15, Batch 150/938, Loss: 0.2511
Epoch 14/15, Batch 200/938, Loss: 0.2407
Epoch 14/15, Batch 250/938, Loss: 0.2418
Epoch 14/15, Batch 300/938, Loss: 0.2521
Epoch 14/15, Batch 350/938, Loss: 0.2506
Epoch 14/15, Batch 400/938, Loss: 0.2489
Epoch 14/15, Batch 450/938, Loss: 0.2420
Epoch 14/15, Batch 500/938, Loss: 0.2482
Epoch 14/15, Batch 550/938, Loss: 0.2591
Epoch 14/15, Batch 600/938, Loss: 0.2382
Epoch 14/15, Batch 650/938, Loss: 0.2449
Epoch 14/15, Batch 700/938, Loss: 0.2440
Epoch 14/15, Batch 750/938, Loss: 0.2473
Epoch 14/15, Batch 800/938, Loss: 0.2421
Epoch 14/15, Batch 850/938, Loss: 0.2516
Epoch 14/15, Batch 900/938, Loss: 0.2431
Epoch 14/15 completed. Test Loss: 0.1843
Epoch 15/15, Batch 0/938, Loss: 0.2453
Epoch 15/15, Batch 50/938, Loss: 0.2596
Epoch 15/15, Batch 100/938, Loss: 0.2551
Epoch 15/15, Batch 150/938, Loss: 0.2485
Epoch 15/15, Batch 200/938, Loss: 0.2414
Epoch 15/15, Batch 250/938, Loss: 0.2397
Epoch 15/15, Batch 300/938, Loss: 0.2395
Epoch 15/15, Batch 350/938, Loss: 0.2410
Epoch 15/15, Batch 400/938, Loss: 0.2474
Epoch 15/15, Batch 450/938, Loss: 0.2586
Epoch 15/15, Batch 500/938, Loss: 0.2450
Epoch 15/15, Batch 550/938, Loss: 0.2593
Epoch 15/15, Batch 600/938, Loss: 0.2470
Epoch 15/15, Batch 650/938, Loss: 0.2482
Epoch 15/15, Batch 700/938, Loss: 0.2426
Epoch 15/15, Batch 750/938, Loss: 0.2484
Epoch 15/15, Batch 800/938, Loss: 0.2543
Epoch 15/15, Batch 850/938, Loss: 0.2503
Epoch 15/15, Batch 900/938, Loss: 0.2564
Epoch 15/15 completed. Test Loss: 0.1841
No description has been provided for this image
samples shape:  (100, 28, 28, 3)
No description has been provided for this image
samples shape:  (100, 28, 28, 3)
No description has been provided for this image

Question 4: Causal Transformer: Tokenized Images¶

Image Tokenization with Vector Quanization¶

Part (a) Image Quantization¶

Above, we implemented iGPT, which autoregressivly predicts raw pixels. Transformers have quadratic complexity in the sequence length which prevents this naive approach from scaling well to large images.

The space of natural images often contains very correlated information. This suggests we can learn a reduced representation. VQVAE is a method that does just that, learning to map images to a more compact discrete set of tokens. We will cover this method in more detail in future lectures. The only thing you need to know now is that we can learn an encoder (and corresponding decoder), which can extract a discrete representation from an image.

If you are curious, checkout the VQVAE paper to learn more: https://arxiv.org/abs/1711.00937 (we will cover this in a future lecture though!)

In this part, we provide a pre-trained VQVAE model, which consists of:

  • encoder to tokenize the images
  • the decoder to recover the image
  • a token vocabulary of VQVAE_MODEL.n_embeddings

Below is the code for loading the VQ model. Note that VQVAE encoding process is lossy, so the decoded images will not be the exact same as the input. Some blurriness in the recovered image is to be expected. The docstrings of the relevant methods you will need for the VQVAE_MODEL are provided below for your convenience.

We will use 2 colored mnist datasets in this part. The first is the same dataset used in previous parts. The second, hads a colored digit on a differently colored background. We will call these datasets Colored MNIST and Colored MNIST v2. Note that the vqvae is trained per dataset.

You will provide these deliverables

  1. Use the provided encoder model to quantize the images then inspect the recovered images by applying the decoder for each of the two datasets
In [22]:
# @property
# def n_embeddings(self) -> int:
#     """The size of the token vocabulary"""
#    
# def quantize(self, x: np.ndarray) -> np.ndarray:
#     """Quantize an image x.
#
#     Args:
#         x (np.ndarray, dtype=int): Image to quantize. shape=(batch_size, 28, 28, 3). Values in [0, 3].
#
#     Returns:
#         np.ndarray: Quantized image. shape=(batch_size, 7, 7). Values in [0, n_embeddings]
#     """
#    
# def decode(self, z_index: np.ndarray) -> np.ndarray:
#     """Decode a quantized image.
#
#     Args:
#         z_index (np.ndarray, dtype=int): Quantized image. shape=(batch_size, 7, 7). Values in [0, n_embeddings].
#
#     Returns:
#         np.ndarray: Decoded image. shape=(batch_size, 28, 28, 3). Values in [0, 3].
#     """
# 
In [23]:
def q4_a(images, vqvae):
    """
    images: (B, H, W, C), the images to pass through the encoder and decoder of the vqvae
    vqvae: a vqvae model, trained on the relevant dataset

    Returns
    - a numpy array of size (2, H, W, C) of the decoded image
    """
    print(vqvae.n_embeddings)
    
    quantized_images = vqvae.quantize(images)
    # print shape of quantized_images
    print("quantinzed images:", quantized_images)
    print("quantized_images shape: ", quantized_images.shape)
    autoencoded_images = vqvae.decode(quantized_images)
    return autoencoded_images
In [24]:
q4a_save_results(1, q4_a)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.061281282..1.1016651].
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
1024
quantinzed images: tensor([[[ 161,  198,  121,  218,  645,  171,  272],
         [ 264,  191,  110,  193,  844,  334,  440],
         [ 935,  730,  386, 1020,  657,  218,  260],
         [ 145,  544,  730,  835,  702,  508,   96],
         [1014,  134,  722,  906,  738,  697,  811],
         [ 884,  268,  165,   94,  952,  821,  346],
         [ 228,  647,  429,  722,  982,  872,  582]],

        [[ 579,  228,  811,  219,  811,  569,   57],
         [ 749,  699,   11,  305,  925,  830,  395],
         [ 145,  593,  907,  422,  421,  533,  130],
         [ 769,  429,  342,  201,  261,  309,  348],
         [ 272,  609,  409,  884,  253,   19,  643],
         [ 250,  740,  465,  253,  772,  264,  228],
         [ 376,  534,  832,   18,  922,  134,  354]]])
quantized_images shape:  torch.Size([2, 7, 7])
samples shape:  (4, 28, 28, 3)
No description has been provided for this image
In [25]:
q4a_save_results(2, q4_a)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.08431109..1.1520311].
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
1024
quantinzed images: tensor([[[288,  75, 641,  75, 907, 907, 288],
         [402, 964, 265, 636, 425, 427, 402],
         [907, 993, 616, 504, 847, 718, 402],
         [ 75, 896, 883, 274, 888, 288, 421],
         [ 75, 641, 964, 419, 432, 421, 421],
         [336, 451, 859, 904, 117, 402, 288],
         [117, 694, 330, 336, 402, 288, 421]],

        [[334, 779, 334, 226, 637, 779, 242],
         [637, 950, 132, 914, 922, 802, 779],
         [179, 253, 651, 167, 937, 713, 779],
         [779, 675, 231, 132, 179, 939, 253],
         [779, 928, 380, 435, 369, 136, 468],
         [779, 928, 939, 859, 211, 625, 637],
         [779, 334, 309, 435, 242, 468, 637]]])
quantized_images shape:  torch.Size([2, 7, 7])
samples shape:  (4, 28, 28, 3)
No description has been provided for this image

Part (b) Autoregressive Transformer on Colored Shapes and MNIST with Vector Quantization¶

We can use the VQVAE to tokenize an image dataset. This will result in a much smaller sequence length than the approach we tried in Question 3(b). For this part, train a transformer on the dataset tokenized by the VQVAE.

This is a simplified version of the approach used in VQGAN VQGAN -> Section 3.2: Learning the Composition of Images with Transformers (Again, we will cover this in more detail in a future lecture!)

Update the following hyperparameters:

  • layers: 4 (we can train a bigger transformer now since less memory is used per input!)
  • 30 epochs

You will provide these deliverables

  1. Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and test data (for your entire test set). Code is provided that automatically plots the training curves.
  2. Report the final test set performance of your final model
  3. 100 samples from the final trained model
In [26]:
def create_tokenized_data(data, image_shape, batch_size, vqvae):
    H ,W ,C = image_shape
    data_tokens = vqvae.quantize(data)
    data_tokens = np.reshape(data_tokens, (data_tokens.shape[0], 7, 7))
    data_flat = np.reshape(data_tokens, (data_tokens.shape[0], -1)) # (batch_size, 49)
    
    dataset = torch.utils.data.TensorDataset(data_flat)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader
In [27]:
def q4_b(train_data, test_data, image_shape, dset_id, vqvae):
    """
    train_data: A (n_train, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
    test_data: A (n_test, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
    image_shape: (H, W, C), height, width, and # of channels of the image
    dset_id: An identifying number of which dataset is given (1 or 2). Most likely
            used to set different hyperparameters for different datasets
    vqvae: a vqvae model, trained on dataset dset_id

    Returns
    - a (# of training iterations,) numpy array of train_losses evaluated every minibatch
    - a (# of epochs + 1,) numpy array of test_losses evaluated once at initialization and after each epoch
    - a numpy array of size (100, H, C, W) of samples with values in {0, 1, 2, 3}
    """
    
    H, W, C = image_shape
    # initialize hyperparameters
    batch_size = 128
    learning_rate = 1e-3
    num_epochs = 30
    d_model = 128
    n_heads = 4
    n_layers = 4
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # determine sequence length and vocab size
    sequence_length = 7 * 7 + 1  # +1 for <bos> token
    vocab_size = vqvae.n_embeddings
    
    # create dataloaders
    train_loader = create_tokenized_data(train_data, image_shape, batch_size, vqvae)
    test_loader = create_tokenized_data(test_data, image_shape, batch_size, vqvae)
    
    # test the dataloader

    model = iGPT(vocab_size, sequence_length, d_model, n_heads, n_layers).to(device)
    
    train_losses, test_losses = train_igpt(model, train_loader, test_loader, 
                                            sequence_length, vocab_size, device,
                                            num_epochs, learning_rate)
    token_image_shape = (7,7,1)
    samples, _ = generate_samples(model, sequence_length, vocab_size, token_image_shape, device)
    # decode the samples
    print("samples shape: ", samples.shape)
    samples = samples.squeeze(-1)
    print("samples shape: ", samples.shape)
    samples = vqvae.decode(samples)
    
    return train_losses, test_losses, samples

Results¶

Once you've implemented q4_b, execute the cells below to visualize and save your results

In [28]:
q4b_save_results(1, q4_b)
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
Epoch 1/30, Batch 0/469, Loss: 7.1077
Epoch 1/30, Batch 50/469, Loss: 6.9837
Epoch 1/30, Batch 100/469, Loss: 6.3427
Epoch 1/30, Batch 150/469, Loss: 5.9725
Epoch 1/30, Batch 200/469, Loss: 5.8723
Epoch 1/30, Batch 250/469, Loss: 5.5979
Epoch 1/30, Batch 300/469, Loss: 5.3181
Epoch 1/30, Batch 350/469, Loss: 5.1249
Epoch 1/30, Batch 400/469, Loss: 4.9939
Epoch 1/30, Batch 450/469, Loss: 4.9383
Epoch 1/30 completed. Test Loss: 4.8370
Epoch 2/30, Batch 0/469, Loss: 4.8403
Epoch 2/30, Batch 50/469, Loss: 4.7477
Epoch 2/30, Batch 100/469, Loss: 4.7367
Epoch 2/30, Batch 150/469, Loss: 4.7285
Epoch 2/30, Batch 200/469, Loss: 4.6845
Epoch 2/30, Batch 250/469, Loss: 4.6479
Epoch 2/30, Batch 300/469, Loss: 4.5854
Epoch 2/30, Batch 350/469, Loss: 4.5794
Epoch 2/30, Batch 400/469, Loss: 4.5112
Epoch 2/30, Batch 450/469, Loss: 4.4655
Epoch 2/30 completed. Test Loss: 4.3398
Epoch 3/30, Batch 0/469, Loss: 4.3931
Epoch 3/30, Batch 50/469, Loss: 4.3437
Epoch 3/30, Batch 100/469, Loss: 4.3613
Epoch 3/30, Batch 150/469, Loss: 4.3126
Epoch 3/30, Batch 200/469, Loss: 4.3593
Epoch 3/30, Batch 250/469, Loss: 4.2809
Epoch 3/30, Batch 300/469, Loss: 4.2827
Epoch 3/30, Batch 350/469, Loss: 4.2148
Epoch 3/30, Batch 400/469, Loss: 4.1774
Epoch 3/30, Batch 450/469, Loss: 4.2373
Epoch 3/30 completed. Test Loss: 4.0282
Epoch 4/30, Batch 0/469, Loss: 4.1408
Epoch 4/30, Batch 50/469, Loss: 4.0903
Epoch 4/30, Batch 100/469, Loss: 4.0367
Epoch 4/30, Batch 150/469, Loss: 4.0997
Epoch 4/30, Batch 200/469, Loss: 4.0682
Epoch 4/30, Batch 250/469, Loss: 4.0462
Epoch 4/30, Batch 300/469, Loss: 4.0812
Epoch 4/30, Batch 350/469, Loss: 4.0188
Epoch 4/30, Batch 400/469, Loss: 4.0250
Epoch 4/30, Batch 450/469, Loss: 3.9967
Epoch 4/30 completed. Test Loss: 3.8461
Epoch 5/30, Batch 0/469, Loss: 3.8849
Epoch 5/30, Batch 50/469, Loss: 3.9655
Epoch 5/30, Batch 100/469, Loss: 3.9511
Epoch 5/30, Batch 150/469, Loss: 3.9692
Epoch 5/30, Batch 200/469, Loss: 3.9061
Epoch 5/30, Batch 250/469, Loss: 3.9744
Epoch 5/30, Batch 300/469, Loss: 3.8557
Epoch 5/30, Batch 350/469, Loss: 3.8580
Epoch 5/30, Batch 400/469, Loss: 3.9167
Epoch 5/30, Batch 450/469, Loss: 3.8119
Epoch 5/30 completed. Test Loss: 3.7142
Epoch 6/30, Batch 0/469, Loss: 3.8242
Epoch 6/30, Batch 50/469, Loss: 3.8132
Epoch 6/30, Batch 100/469, Loss: 3.9260
Epoch 6/30, Batch 150/469, Loss: 3.8289
Epoch 6/30, Batch 200/469, Loss: 3.7896
Epoch 6/30, Batch 250/469, Loss: 3.7953
Epoch 6/30, Batch 300/469, Loss: 3.7992
Epoch 6/30, Batch 350/469, Loss: 3.8512
Epoch 6/30, Batch 400/469, Loss: 3.7463
Epoch 6/30, Batch 450/469, Loss: 3.7673
Epoch 6/30 completed. Test Loss: 3.6134
Epoch 7/30, Batch 0/469, Loss: 3.7481
Epoch 7/30, Batch 50/469, Loss: 3.7231
Epoch 7/30, Batch 100/469, Loss: 3.7169
Epoch 7/30, Batch 150/469, Loss: 3.7061
Epoch 7/30, Batch 200/469, Loss: 3.7337
Epoch 7/30, Batch 250/469, Loss: 3.6886
Epoch 7/30, Batch 300/469, Loss: 3.6998
Epoch 7/30, Batch 350/469, Loss: 3.7366
Epoch 7/30, Batch 400/469, Loss: 3.6939
Epoch 7/30, Batch 450/469, Loss: 3.6395
Epoch 7/30 completed. Test Loss: 3.5308
Epoch 8/30, Batch 0/469, Loss: 3.6519
Epoch 8/30, Batch 50/469, Loss: 3.6976
Epoch 8/30, Batch 100/469, Loss: 3.6389
Epoch 8/30, Batch 150/469, Loss: 3.7058
Epoch 8/30, Batch 200/469, Loss: 3.6129
Epoch 8/30, Batch 250/469, Loss: 3.7350
Epoch 8/30, Batch 300/469, Loss: 3.6363
Epoch 8/30, Batch 350/469, Loss: 3.6249
Epoch 8/30, Batch 400/469, Loss: 3.5829
Epoch 8/30, Batch 450/469, Loss: 3.6489
Epoch 8/30 completed. Test Loss: 3.4716
Epoch 9/30, Batch 0/469, Loss: 3.5837
Epoch 9/30, Batch 50/469, Loss: 3.6476
Epoch 9/30, Batch 100/469, Loss: 3.5756
Epoch 9/30, Batch 150/469, Loss: 3.6084
Epoch 9/30, Batch 200/469, Loss: 3.6124
Epoch 9/30, Batch 250/469, Loss: 3.6120
Epoch 9/30, Batch 300/469, Loss: 3.6157
Epoch 9/30, Batch 350/469, Loss: 3.6508
Epoch 9/30, Batch 400/469, Loss: 3.6172
Epoch 9/30, Batch 450/469, Loss: 3.5900
Epoch 9/30 completed. Test Loss: 3.4200
Epoch 10/30, Batch 0/469, Loss: 3.5248
Epoch 10/30, Batch 50/469, Loss: 3.5483
Epoch 10/30, Batch 100/469, Loss: 3.6200
Epoch 10/30, Batch 150/469, Loss: 3.5654
Epoch 10/30, Batch 200/469, Loss: 3.5237
Epoch 10/30, Batch 250/469, Loss: 3.5251
Epoch 10/30, Batch 300/469, Loss: 3.5507
Epoch 10/30, Batch 350/469, Loss: 3.5162
Epoch 10/30, Batch 400/469, Loss: 3.4542
Epoch 10/30, Batch 450/469, Loss: 3.5177
Epoch 10/30 completed. Test Loss: 3.3708
Epoch 11/30, Batch 0/469, Loss: 3.4839
Epoch 11/30, Batch 50/469, Loss: 3.5139
Epoch 11/30, Batch 100/469, Loss: 3.5196
Epoch 11/30, Batch 150/469, Loss: 3.4121
Epoch 11/30, Batch 200/469, Loss: 3.5378
Epoch 11/30, Batch 250/469, Loss: 3.5556
Epoch 11/30, Batch 300/469, Loss: 3.5336
Epoch 11/30, Batch 350/469, Loss: 3.5014
Epoch 11/30, Batch 400/469, Loss: 3.5298
Epoch 11/30, Batch 450/469, Loss: 3.5157
Epoch 11/30 completed. Test Loss: 3.3380
Epoch 12/30, Batch 0/469, Loss: 3.5097
Epoch 12/30, Batch 50/469, Loss: 3.4858
Epoch 12/30, Batch 100/469, Loss: 3.4831
Epoch 12/30, Batch 150/469, Loss: 3.4702
Epoch 12/30, Batch 200/469, Loss: 3.5245
Epoch 12/30, Batch 250/469, Loss: 3.4701
Epoch 12/30, Batch 300/469, Loss: 3.4306
Epoch 12/30, Batch 350/469, Loss: 3.5069
Epoch 12/30, Batch 400/469, Loss: 3.5154
Epoch 12/30, Batch 450/469, Loss: 3.4283
Epoch 12/30 completed. Test Loss: 3.3064
Epoch 13/30, Batch 0/469, Loss: 3.4527
Epoch 13/30, Batch 50/469, Loss: 3.4045
Epoch 13/30, Batch 100/469, Loss: 3.4684
Epoch 13/30, Batch 150/469, Loss: 3.4400
Epoch 13/30, Batch 200/469, Loss: 3.4590
Epoch 13/30, Batch 250/469, Loss: 3.3944
Epoch 13/30, Batch 300/469, Loss: 3.4304
Epoch 13/30, Batch 350/469, Loss: 3.4733
Epoch 13/30, Batch 400/469, Loss: 3.4694
Epoch 13/30, Batch 450/469, Loss: 3.4476
Epoch 13/30 completed. Test Loss: 3.2770
Epoch 14/30, Batch 0/469, Loss: 3.3833
Epoch 14/30, Batch 50/469, Loss: 3.4388
Epoch 14/30, Batch 100/469, Loss: 3.3825
Epoch 14/30, Batch 150/469, Loss: 3.4433
Epoch 14/30, Batch 200/469, Loss: 3.4666
Epoch 14/30, Batch 250/469, Loss: 3.4811
Epoch 14/30, Batch 300/469, Loss: 3.4056
Epoch 14/30, Batch 350/469, Loss: 3.4733
Epoch 14/30, Batch 400/469, Loss: 3.4005
Epoch 14/30, Batch 450/469, Loss: 3.4185
Epoch 14/30 completed. Test Loss: 3.2516
Epoch 15/30, Batch 0/469, Loss: 3.4239
Epoch 15/30, Batch 50/469, Loss: 3.3823
Epoch 15/30, Batch 100/469, Loss: 3.3771
Epoch 15/30, Batch 150/469, Loss: 3.3600
Epoch 15/30, Batch 200/469, Loss: 3.3744
Epoch 15/30, Batch 250/469, Loss: 3.4164
Epoch 15/30, Batch 300/469, Loss: 3.4302
Epoch 15/30, Batch 350/469, Loss: 3.4463
Epoch 15/30, Batch 400/469, Loss: 3.4218
Epoch 15/30, Batch 450/469, Loss: 3.3389
Epoch 15/30 completed. Test Loss: 3.2330
Epoch 16/30, Batch 0/469, Loss: 3.3724
Epoch 16/30, Batch 50/469, Loss: 3.3766
Epoch 16/30, Batch 100/469, Loss: 3.4000
Epoch 16/30, Batch 150/469, Loss: 3.4826
Epoch 16/30, Batch 200/469, Loss: 3.3904
Epoch 16/30, Batch 250/469, Loss: 3.3898
Epoch 16/30, Batch 300/469, Loss: 3.3781
Epoch 16/30, Batch 350/469, Loss: 3.4230
Epoch 16/30, Batch 400/469, Loss: 3.3931
Epoch 16/30, Batch 450/469, Loss: 3.3893
Epoch 16/30 completed. Test Loss: 3.2147
Epoch 17/30, Batch 0/469, Loss: 3.4125
Epoch 17/30, Batch 50/469, Loss: 3.3491
Epoch 17/30, Batch 100/469, Loss: 3.3792
Epoch 17/30, Batch 150/469, Loss: 3.3654
Epoch 17/30, Batch 200/469, Loss: 3.3137
Epoch 17/30, Batch 250/469, Loss: 3.3725
Epoch 17/30, Batch 300/469, Loss: 3.4090
Epoch 17/30, Batch 350/469, Loss: 3.3514
Epoch 17/30, Batch 400/469, Loss: 3.3518
Epoch 17/30, Batch 450/469, Loss: 3.3111
Epoch 17/30 completed. Test Loss: 3.1994
Epoch 18/30, Batch 0/469, Loss: 3.3238
Epoch 18/30, Batch 50/469, Loss: 3.3128
Epoch 18/30, Batch 100/469, Loss: 3.4004
Epoch 18/30, Batch 150/469, Loss: 3.3306
Epoch 18/30, Batch 200/469, Loss: 3.2622
Epoch 18/30, Batch 250/469, Loss: 3.3982
Epoch 18/30, Batch 300/469, Loss: 3.3378
Epoch 18/30, Batch 350/469, Loss: 3.3567
Epoch 18/30, Batch 400/469, Loss: 3.3094
Epoch 18/30, Batch 450/469, Loss: 3.3316
Epoch 18/30 completed. Test Loss: 3.1846
Epoch 19/30, Batch 0/469, Loss: 3.3406
Epoch 19/30, Batch 50/469, Loss: 3.2802
Epoch 19/30, Batch 100/469, Loss: 3.3789
Epoch 19/30, Batch 150/469, Loss: 3.3395
Epoch 19/30, Batch 200/469, Loss: 3.3386
Epoch 19/30, Batch 250/469, Loss: 3.3567
Epoch 19/30, Batch 300/469, Loss: 3.3761
Epoch 19/30, Batch 350/469, Loss: 3.3530
Epoch 19/30, Batch 400/469, Loss: 3.3722
Epoch 19/30, Batch 450/469, Loss: 3.2983
Epoch 19/30 completed. Test Loss: 3.1693
Epoch 20/30, Batch 0/469, Loss: 3.3158
Epoch 20/30, Batch 50/469, Loss: 3.2793
Epoch 20/30, Batch 100/469, Loss: 3.3011
Epoch 20/30, Batch 150/469, Loss: 3.3184
Epoch 20/30, Batch 200/469, Loss: 3.3578
Epoch 20/30, Batch 250/469, Loss: 3.2866
Epoch 20/30, Batch 300/469, Loss: 3.3581
Epoch 20/30, Batch 350/469, Loss: 3.3815
Epoch 20/30, Batch 400/469, Loss: 3.3537
Epoch 20/30, Batch 450/469, Loss: 3.3523
Epoch 20/30 completed. Test Loss: 3.1566
Epoch 21/30, Batch 0/469, Loss: 3.3200
Epoch 21/30, Batch 50/469, Loss: 3.2771
Epoch 21/30, Batch 100/469, Loss: 3.3150
Epoch 21/30, Batch 150/469, Loss: 3.2530
Epoch 21/30, Batch 200/469, Loss: 3.2572
Epoch 21/30, Batch 250/469, Loss: 3.3431
Epoch 21/30, Batch 300/469, Loss: 3.2836
Epoch 21/30, Batch 350/469, Loss: 3.2962
Epoch 21/30, Batch 400/469, Loss: 3.3519
Epoch 21/30, Batch 450/469, Loss: 3.3681
Epoch 21/30 completed. Test Loss: 3.1470
Epoch 22/30, Batch 0/469, Loss: 3.3074
Epoch 22/30, Batch 50/469, Loss: 3.2974
Epoch 22/30, Batch 100/469, Loss: 3.2479
Epoch 22/30, Batch 150/469, Loss: 3.2646
Epoch 22/30, Batch 200/469, Loss: 3.2972
Epoch 22/30, Batch 250/469, Loss: 3.2936
Epoch 22/30, Batch 300/469, Loss: 3.2462
Epoch 22/30, Batch 350/469, Loss: 3.3011
Epoch 22/30, Batch 400/469, Loss: 3.2563
Epoch 22/30, Batch 450/469, Loss: 3.3189
Epoch 22/30 completed. Test Loss: 3.1397
Epoch 23/30, Batch 0/469, Loss: 3.2839
Epoch 23/30, Batch 50/469, Loss: 3.2816
Epoch 23/30, Batch 100/469, Loss: 3.2867
Epoch 23/30, Batch 150/469, Loss: 3.2996
Epoch 23/30, Batch 200/469, Loss: 3.2771
Epoch 23/30, Batch 250/469, Loss: 3.3240
Epoch 23/30, Batch 300/469, Loss: 3.2692
Epoch 23/30, Batch 350/469, Loss: 3.2393
Epoch 23/30, Batch 400/469, Loss: 3.3506
Epoch 23/30, Batch 450/469, Loss: 3.2098
Epoch 23/30 completed. Test Loss: 3.1326
Epoch 24/30, Batch 0/469, Loss: 3.2436
Epoch 24/30, Batch 50/469, Loss: 3.2608
Epoch 24/30, Batch 100/469, Loss: 3.2284
Epoch 24/30, Batch 150/469, Loss: 3.3140
Epoch 24/30, Batch 200/469, Loss: 3.3108
Epoch 24/30, Batch 250/469, Loss: 3.2630
Epoch 24/30, Batch 300/469, Loss: 3.3307
Epoch 24/30, Batch 350/469, Loss: 3.2905
Epoch 24/30, Batch 400/469, Loss: 3.1866
Epoch 24/30, Batch 450/469, Loss: 3.2867
Epoch 24/30 completed. Test Loss: 3.1256
Epoch 25/30, Batch 0/469, Loss: 3.2576
Epoch 25/30, Batch 50/469, Loss: 3.2438
Epoch 25/30, Batch 100/469, Loss: 3.3221
Epoch 25/30, Batch 150/469, Loss: 3.2446
Epoch 25/30, Batch 200/469, Loss: 3.2813
Epoch 25/30, Batch 250/469, Loss: 3.2923
Epoch 25/30, Batch 300/469, Loss: 3.2792
Epoch 25/30, Batch 350/469, Loss: 3.2515
Epoch 25/30, Batch 400/469, Loss: 3.2860
Epoch 25/30, Batch 450/469, Loss: 3.2242
Epoch 25/30 completed. Test Loss: 3.1216
Epoch 26/30, Batch 0/469, Loss: 3.3549
Epoch 26/30, Batch 50/469, Loss: 3.2618
Epoch 26/30, Batch 100/469, Loss: 3.3040
Epoch 26/30, Batch 150/469, Loss: 3.1989
Epoch 26/30, Batch 200/469, Loss: 3.2737
Epoch 26/30, Batch 250/469, Loss: 3.2807
Epoch 26/30, Batch 300/469, Loss: 3.2195
Epoch 26/30, Batch 350/469, Loss: 3.3169
Epoch 26/30, Batch 400/469, Loss: 3.2727
Epoch 26/30, Batch 450/469, Loss: 3.2274
Epoch 26/30 completed. Test Loss: 3.1182
Epoch 27/30, Batch 0/469, Loss: 3.2861
Epoch 27/30, Batch 50/469, Loss: 3.1659
Epoch 27/30, Batch 100/469, Loss: 3.2290
Epoch 27/30, Batch 150/469, Loss: 3.2046
Epoch 27/30, Batch 200/469, Loss: 3.2605
Epoch 27/30, Batch 250/469, Loss: 3.2226
Epoch 27/30, Batch 300/469, Loss: 3.2486
Epoch 27/30, Batch 350/469, Loss: 3.2122
Epoch 27/30, Batch 400/469, Loss: 3.2031
Epoch 27/30, Batch 450/469, Loss: 3.2565
Epoch 27/30 completed. Test Loss: 3.1165
Epoch 28/30, Batch 0/469, Loss: 3.3085
Epoch 28/30, Batch 50/469, Loss: 3.2284
Epoch 28/30, Batch 100/469, Loss: 3.2412
Epoch 28/30, Batch 150/469, Loss: 3.2820
Epoch 28/30, Batch 200/469, Loss: 3.2468
Epoch 28/30, Batch 250/469, Loss: 3.2988
Epoch 28/30, Batch 300/469, Loss: 3.3192
Epoch 28/30, Batch 350/469, Loss: 3.2914
Epoch 28/30, Batch 400/469, Loss: 3.2607
Epoch 28/30, Batch 450/469, Loss: 3.2927
Epoch 28/30 completed. Test Loss: 3.1154
Epoch 29/30, Batch 0/469, Loss: 3.2820
Epoch 29/30, Batch 50/469, Loss: 3.2303
Epoch 29/30, Batch 100/469, Loss: 3.1826
Epoch 29/30, Batch 150/469, Loss: 3.2421
Epoch 29/30, Batch 200/469, Loss: 3.2241
Epoch 29/30, Batch 250/469, Loss: 3.3153
Epoch 29/30, Batch 300/469, Loss: 3.2545
Epoch 29/30, Batch 350/469, Loss: 3.1512
Epoch 29/30, Batch 400/469, Loss: 3.3258
Epoch 29/30, Batch 450/469, Loss: 3.2103
Epoch 29/30 completed. Test Loss: 3.1145
Epoch 30/30, Batch 0/469, Loss: 3.2814
Epoch 30/30, Batch 50/469, Loss: 3.2195
Epoch 30/30, Batch 100/469, Loss: 3.2526
Epoch 30/30, Batch 150/469, Loss: 3.2314
Epoch 30/30, Batch 200/469, Loss: 3.2903
Epoch 30/30, Batch 250/469, Loss: 3.2468
Epoch 30/30, Batch 300/469, Loss: 3.2556
Epoch 30/30, Batch 350/469, Loss: 3.2143
Epoch 30/30, Batch 400/469, Loss: 3.3093
Epoch 30/30, Batch 450/469, Loss: 3.1963
Epoch 30/30 completed. Test Loss: 3.1144
samples shape:  (100, 7, 7, 1)
samples shape:  (100, 7, 7)
Final Test Loss: 3.1144
No description has been provided for this image
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.13376339..1.283002].
samples shape:  (100, 28, 28, 3)
No description has been provided for this image
In [29]:
q4b_save_results(2, q4_b)
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
Epoch 1/30, Batch 0/469, Loss: 7.1028
Epoch 1/30, Batch 50/469, Loss: 7.0020
Epoch 1/30, Batch 100/469, Loss: 6.5681
Epoch 1/30, Batch 150/469, Loss: 6.2534
Epoch 1/30, Batch 200/469, Loss: 5.9077
Epoch 1/30, Batch 250/469, Loss: 5.1703
Epoch 1/30, Batch 300/469, Loss: 4.6615
Epoch 1/30, Batch 350/469, Loss: 4.3008
Epoch 1/30, Batch 400/469, Loss: 4.2373
Epoch 1/30, Batch 450/469, Loss: 4.1927
Epoch 1/30 completed. Test Loss: 4.0851
Epoch 2/30, Batch 0/469, Loss: 4.1059
Epoch 2/30, Batch 50/469, Loss: 4.0950
Epoch 2/30, Batch 100/469, Loss: 4.1215
Epoch 2/30, Batch 150/469, Loss: 4.1579
Epoch 2/30, Batch 200/469, Loss: 3.9415
Epoch 2/30, Batch 250/469, Loss: 3.9819
Epoch 2/30, Batch 300/469, Loss: 3.9069
Epoch 2/30, Batch 350/469, Loss: 4.0260
Epoch 2/30, Batch 400/469, Loss: 3.9660
Epoch 2/30, Batch 450/469, Loss: 3.9541
Epoch 2/30 completed. Test Loss: 3.8060
Epoch 3/30, Batch 0/469, Loss: 3.8758
Epoch 3/30, Batch 50/469, Loss: 3.8279
Epoch 3/30, Batch 100/469, Loss: 3.7430
Epoch 3/30, Batch 150/469, Loss: 3.7442
Epoch 3/30, Batch 200/469, Loss: 3.7346
Epoch 3/30, Batch 250/469, Loss: 3.7924
Epoch 3/30, Batch 300/469, Loss: 3.7722
Epoch 3/30, Batch 350/469, Loss: 3.6915
Epoch 3/30, Batch 400/469, Loss: 3.5741
Epoch 3/30, Batch 450/469, Loss: 3.6316
Epoch 3/30 completed. Test Loss: 3.5549
Epoch 4/30, Batch 0/469, Loss: 3.6727
Epoch 4/30, Batch 50/469, Loss: 3.6121
Epoch 4/30, Batch 100/469, Loss: 3.6661
Epoch 4/30, Batch 150/469, Loss: 3.5980
Epoch 4/30, Batch 200/469, Loss: 3.6620
Epoch 4/30, Batch 250/469, Loss: 3.4974
Epoch 4/30, Batch 300/469, Loss: 3.4834
Epoch 4/30, Batch 350/469, Loss: 3.4665
Epoch 4/30, Batch 400/469, Loss: 3.4790
Epoch 4/30, Batch 450/469, Loss: 3.5781
Epoch 4/30 completed. Test Loss: 3.4005
Epoch 5/30, Batch 0/469, Loss: 3.4673
Epoch 5/30, Batch 50/469, Loss: 3.4440
Epoch 5/30, Batch 100/469, Loss: 3.5104
Epoch 5/30, Batch 150/469, Loss: 3.4669
Epoch 5/30, Batch 200/469, Loss: 3.4606
Epoch 5/30, Batch 250/469, Loss: 3.4286
Epoch 5/30, Batch 300/469, Loss: 3.3574
Epoch 5/30, Batch 350/469, Loss: 3.4290
Epoch 5/30, Batch 400/469, Loss: 3.4704
Epoch 5/30, Batch 450/469, Loss: 3.4133
Epoch 5/30 completed. Test Loss: 3.3082
Epoch 6/30, Batch 0/469, Loss: 3.3975
Epoch 6/30, Batch 50/469, Loss: 3.4547
Epoch 6/30, Batch 100/469, Loss: 3.3121
Epoch 6/30, Batch 150/469, Loss: 3.3135
Epoch 6/30, Batch 200/469, Loss: 3.3450
Epoch 6/30, Batch 250/469, Loss: 3.4129
Epoch 6/30, Batch 300/469, Loss: 3.4192
Epoch 6/30, Batch 350/469, Loss: 3.3552
Epoch 6/30, Batch 400/469, Loss: 3.2981
Epoch 6/30, Batch 450/469, Loss: 3.2434
Epoch 6/30 completed. Test Loss: 3.2377
Epoch 7/30, Batch 0/469, Loss: 3.3043
Epoch 7/30, Batch 50/469, Loss: 3.3149
Epoch 7/30, Batch 100/469, Loss: 3.3291
Epoch 7/30, Batch 150/469, Loss: 3.3469
Epoch 7/30, Batch 200/469, Loss: 3.3369
Epoch 7/30, Batch 250/469, Loss: 3.3725
Epoch 7/30, Batch 300/469, Loss: 3.3349
Epoch 7/30, Batch 350/469, Loss: 3.2290
Epoch 7/30, Batch 400/469, Loss: 3.2064
Epoch 7/30, Batch 450/469, Loss: 3.2823
Epoch 7/30 completed. Test Loss: 3.1858
Epoch 8/30, Batch 0/469, Loss: 3.2766
Epoch 8/30, Batch 50/469, Loss: 3.2861
Epoch 8/30, Batch 100/469, Loss: 3.2250
Epoch 8/30, Batch 150/469, Loss: 3.2699
Epoch 8/30, Batch 200/469, Loss: 3.2300
Epoch 8/30, Batch 250/469, Loss: 3.2301
Epoch 8/30, Batch 300/469, Loss: 3.2668
Epoch 8/30, Batch 350/469, Loss: 3.2865
Epoch 8/30, Batch 400/469, Loss: 3.1757
Epoch 8/30, Batch 450/469, Loss: 3.2293
Epoch 8/30 completed. Test Loss: 3.1492
Epoch 9/30, Batch 0/469, Loss: 3.2163
Epoch 9/30, Batch 50/469, Loss: 3.2104
Epoch 9/30, Batch 100/469, Loss: 3.2777
Epoch 9/30, Batch 150/469, Loss: 3.2420
Epoch 9/30, Batch 200/469, Loss: 3.2461
Epoch 9/30, Batch 250/469, Loss: 3.2567
Epoch 9/30, Batch 300/469, Loss: 3.1962
Epoch 9/30, Batch 350/469, Loss: 3.1944
Epoch 9/30, Batch 400/469, Loss: 3.1820
Epoch 9/30, Batch 450/469, Loss: 3.2126
Epoch 9/30 completed. Test Loss: 3.1186
Epoch 10/30, Batch 0/469, Loss: 3.1953
Epoch 10/30, Batch 50/469, Loss: 3.1867
Epoch 10/30, Batch 100/469, Loss: 3.1499
Epoch 10/30, Batch 150/469, Loss: 3.2222
Epoch 10/30, Batch 200/469, Loss: 3.1960
Epoch 10/30, Batch 250/469, Loss: 3.2125
Epoch 10/30, Batch 300/469, Loss: 3.1431
Epoch 10/30, Batch 350/469, Loss: 3.2460
Epoch 10/30, Batch 400/469, Loss: 3.1780
Epoch 10/30, Batch 450/469, Loss: 3.1525
Epoch 10/30 completed. Test Loss: 3.0947
Epoch 11/30, Batch 0/469, Loss: 3.2014
Epoch 11/30, Batch 50/469, Loss: 3.1994
Epoch 11/30, Batch 100/469, Loss: 3.1477
Epoch 11/30, Batch 150/469, Loss: 3.2226
Epoch 11/30, Batch 200/469, Loss: 3.2257
Epoch 11/30, Batch 250/469, Loss: 3.0542
Epoch 11/30, Batch 300/469, Loss: 3.1258
Epoch 11/30, Batch 350/469, Loss: 3.1137
Epoch 11/30, Batch 400/469, Loss: 3.1760
Epoch 11/30, Batch 450/469, Loss: 3.0894
Epoch 11/30 completed. Test Loss: 3.0735
Epoch 12/30, Batch 0/469, Loss: 3.0979
Epoch 12/30, Batch 50/469, Loss: 3.1667
Epoch 12/30, Batch 100/469, Loss: 3.1316
Epoch 12/30, Batch 150/469, Loss: 3.1550
Epoch 12/30, Batch 200/469, Loss: 3.0975
Epoch 12/30, Batch 250/469, Loss: 3.1202
Epoch 12/30, Batch 300/469, Loss: 3.0907
Epoch 12/30, Batch 350/469, Loss: 3.1181
Epoch 12/30, Batch 400/469, Loss: 3.0974
Epoch 12/30, Batch 450/469, Loss: 3.1560
Epoch 12/30 completed. Test Loss: 3.0577
Epoch 13/30, Batch 0/469, Loss: 3.0815
Epoch 13/30, Batch 50/469, Loss: 3.1552
Epoch 13/30, Batch 100/469, Loss: 3.0795
Epoch 13/30, Batch 150/469, Loss: 3.1172
Epoch 13/30, Batch 200/469, Loss: 3.1367
Epoch 13/30, Batch 250/469, Loss: 3.1360
Epoch 13/30, Batch 300/469, Loss: 3.0928
Epoch 13/30, Batch 350/469, Loss: 3.1047
Epoch 13/30, Batch 400/469, Loss: 3.2051
Epoch 13/30, Batch 450/469, Loss: 3.1227
Epoch 13/30 completed. Test Loss: 3.0422
Epoch 14/30, Batch 0/469, Loss: 3.0632
Epoch 14/30, Batch 50/469, Loss: 3.1038
Epoch 14/30, Batch 100/469, Loss: 3.1557
Epoch 14/30, Batch 150/469, Loss: 3.1844
Epoch 14/30, Batch 200/469, Loss: 3.1093
Epoch 14/30, Batch 250/469, Loss: 3.1344
Epoch 14/30, Batch 300/469, Loss: 3.0865
Epoch 14/30, Batch 350/469, Loss: 3.1240
Epoch 14/30, Batch 400/469, Loss: 3.1122
Epoch 14/30, Batch 450/469, Loss: 3.1129
Epoch 14/30 completed. Test Loss: 3.0321
Epoch 15/30, Batch 0/469, Loss: 3.0486
Epoch 15/30, Batch 50/469, Loss: 3.0854
Epoch 15/30, Batch 100/469, Loss: 3.0994
Epoch 15/30, Batch 150/469, Loss: 3.1074
Epoch 15/30, Batch 200/469, Loss: 3.0560
Epoch 15/30, Batch 250/469, Loss: 3.0814
Epoch 15/30, Batch 300/469, Loss: 3.1589
Epoch 15/30, Batch 350/469, Loss: 3.0332
Epoch 15/30, Batch 400/469, Loss: 3.0877
Epoch 15/30, Batch 450/469, Loss: 3.0549
Epoch 15/30 completed. Test Loss: 3.0209
Epoch 16/30, Batch 0/469, Loss: 3.1523
Epoch 16/30, Batch 50/469, Loss: 3.0517
Epoch 16/30, Batch 100/469, Loss: 3.1639
Epoch 16/30, Batch 150/469, Loss: 3.0564
Epoch 16/30, Batch 200/469, Loss: 3.0489
Epoch 16/30, Batch 250/469, Loss: 3.0244
Epoch 16/30, Batch 300/469, Loss: 3.0933
Epoch 16/30, Batch 350/469, Loss: 3.1153
Epoch 16/30, Batch 400/469, Loss: 3.0665
Epoch 16/30, Batch 450/469, Loss: 3.2082
Epoch 16/30 completed. Test Loss: 3.0118
Epoch 17/30, Batch 0/469, Loss: 3.0703
Epoch 17/30, Batch 50/469, Loss: 3.0688
Epoch 17/30, Batch 100/469, Loss: 3.0822
Epoch 17/30, Batch 150/469, Loss: 3.0145
Epoch 17/30, Batch 200/469, Loss: 3.0619
Epoch 17/30, Batch 250/469, Loss: 3.1068
Epoch 17/30, Batch 300/469, Loss: 3.0871
Epoch 17/30, Batch 350/469, Loss: 3.1040
Epoch 17/30, Batch 400/469, Loss: 3.0474
Epoch 17/30, Batch 450/469, Loss: 3.1062
Epoch 17/30 completed. Test Loss: 2.9991
Epoch 18/30, Batch 0/469, Loss: 3.0139
Epoch 18/30, Batch 50/469, Loss: 2.9685
Epoch 18/30, Batch 100/469, Loss: 2.9834
Epoch 18/30, Batch 150/469, Loss: 3.0927
Epoch 18/30, Batch 200/469, Loss: 3.0285
Epoch 18/30, Batch 250/469, Loss: 3.1719
Epoch 18/30, Batch 300/469, Loss: 3.0437
Epoch 18/30, Batch 350/469, Loss: 3.0007
Epoch 18/30, Batch 400/469, Loss: 3.0559
Epoch 18/30, Batch 450/469, Loss: 3.0327
Epoch 18/30 completed. Test Loss: 2.9894
Epoch 19/30, Batch 0/469, Loss: 3.0575
Epoch 19/30, Batch 50/469, Loss: 3.1054
Epoch 19/30, Batch 100/469, Loss: 2.9840
Epoch 19/30, Batch 150/469, Loss: 3.1167
Epoch 19/30, Batch 200/469, Loss: 3.0414
Epoch 19/30, Batch 250/469, Loss: 3.0328
Epoch 19/30, Batch 300/469, Loss: 3.1417
Epoch 19/30, Batch 350/469, Loss: 3.0676
Epoch 19/30, Batch 400/469, Loss: 3.0764
Epoch 19/30, Batch 450/469, Loss: 3.0550
Epoch 19/30 completed. Test Loss: 2.9796
Epoch 20/30, Batch 0/469, Loss: 3.0773
Epoch 20/30, Batch 50/469, Loss: 3.0147
Epoch 20/30, Batch 100/469, Loss: 2.9797
Epoch 20/30, Batch 150/469, Loss: 3.0323
Epoch 20/30, Batch 200/469, Loss: 3.0504
Epoch 20/30, Batch 250/469, Loss: 3.1142
Epoch 20/30, Batch 300/469, Loss: 2.9954
Epoch 20/30, Batch 350/469, Loss: 3.0329
Epoch 20/30, Batch 400/469, Loss: 3.0967
Epoch 20/30, Batch 450/469, Loss: 2.9793
Epoch 20/30 completed. Test Loss: 2.9761
Epoch 21/30, Batch 0/469, Loss: 3.0671
Epoch 21/30, Batch 50/469, Loss: 3.0186
Epoch 21/30, Batch 100/469, Loss: 3.0481
Epoch 21/30, Batch 150/469, Loss: 3.0473
Epoch 21/30, Batch 200/469, Loss: 2.9966
Epoch 21/30, Batch 250/469, Loss: 2.9813
Epoch 21/30, Batch 300/469, Loss: 3.0101
Epoch 21/30, Batch 350/469, Loss: 3.0607
Epoch 21/30, Batch 400/469, Loss: 3.0566
Epoch 21/30, Batch 450/469, Loss: 3.0528
Epoch 21/30 completed. Test Loss: 2.9700
Epoch 22/30, Batch 0/469, Loss: 3.0110
Epoch 22/30, Batch 50/469, Loss: 2.9996
Epoch 22/30, Batch 100/469, Loss: 3.0245
Epoch 22/30, Batch 150/469, Loss: 2.9715
Epoch 22/30, Batch 200/469, Loss: 3.0241
Epoch 22/30, Batch 250/469, Loss: 2.9922
Epoch 22/30, Batch 300/469, Loss: 2.9814
Epoch 22/30, Batch 350/469, Loss: 3.1051
Epoch 22/30, Batch 400/469, Loss: 3.0265
Epoch 22/30, Batch 450/469, Loss: 3.0517
Epoch 22/30 completed. Test Loss: 2.9665
Epoch 23/30, Batch 0/469, Loss: 3.0043
Epoch 23/30, Batch 50/469, Loss: 2.9430
Epoch 23/30, Batch 100/469, Loss: 2.9871
Epoch 23/30, Batch 150/469, Loss: 3.0060
Epoch 23/30, Batch 200/469, Loss: 2.9903
Epoch 23/30, Batch 250/469, Loss: 3.0533
Epoch 23/30, Batch 300/469, Loss: 2.9526
Epoch 23/30, Batch 350/469, Loss: 3.0467
Epoch 23/30, Batch 400/469, Loss: 3.0433
Epoch 23/30, Batch 450/469, Loss: 2.9706
Epoch 23/30 completed. Test Loss: 2.9617
Epoch 24/30, Batch 0/469, Loss: 2.9814
Epoch 24/30, Batch 50/469, Loss: 2.9874
Epoch 24/30, Batch 100/469, Loss: 2.9733
Epoch 24/30, Batch 150/469, Loss: 3.0548
Epoch 24/30, Batch 200/469, Loss: 3.0537
Epoch 24/30, Batch 250/469, Loss: 2.9989
Epoch 24/30, Batch 300/469, Loss: 3.0660
Epoch 24/30, Batch 350/469, Loss: 2.9604
Epoch 24/30, Batch 400/469, Loss: 3.1869
Epoch 24/30, Batch 450/469, Loss: 2.9720
Epoch 24/30 completed. Test Loss: 2.9579
Epoch 25/30, Batch 0/469, Loss: 3.0287
Epoch 25/30, Batch 50/469, Loss: 3.0835
Epoch 25/30, Batch 100/469, Loss: 3.0299
Epoch 25/30, Batch 150/469, Loss: 2.9867
Epoch 25/30, Batch 200/469, Loss: 3.0570
Epoch 25/30, Batch 250/469, Loss: 3.0093
Epoch 25/30, Batch 300/469, Loss: 3.0596
Epoch 25/30, Batch 350/469, Loss: 3.0245
Epoch 25/30, Batch 400/469, Loss: 3.0338
Epoch 25/30, Batch 450/469, Loss: 2.9632
Epoch 25/30 completed. Test Loss: 2.9569
Epoch 26/30, Batch 0/469, Loss: 2.9927
Epoch 26/30, Batch 50/469, Loss: 2.9575
Epoch 26/30, Batch 100/469, Loss: 3.0357
Epoch 26/30, Batch 150/469, Loss: 2.9555
Epoch 26/30, Batch 200/469, Loss: 2.9541
Epoch 26/30, Batch 250/469, Loss: 3.0262
Epoch 26/30, Batch 300/469, Loss: 3.1024
Epoch 26/30, Batch 350/469, Loss: 2.9882
Epoch 26/30, Batch 400/469, Loss: 3.0074
Epoch 26/30, Batch 450/469, Loss: 2.9454
Epoch 26/30 completed. Test Loss: 2.9539
Epoch 27/30, Batch 0/469, Loss: 3.0561
Epoch 27/30, Batch 50/469, Loss: 3.0346
Epoch 27/30, Batch 100/469, Loss: 2.9425
Epoch 27/30, Batch 150/469, Loss: 3.0430
Epoch 27/30, Batch 200/469, Loss: 3.0076
Epoch 27/30, Batch 250/469, Loss: 2.9893
Epoch 27/30, Batch 300/469, Loss: 3.0603
Epoch 27/30, Batch 350/469, Loss: 3.0235
Epoch 27/30, Batch 400/469, Loss: 3.0288
Epoch 27/30, Batch 450/469, Loss: 3.0160
Epoch 27/30 completed. Test Loss: 2.9521
Epoch 28/30, Batch 0/469, Loss: 2.9329
Epoch 28/30, Batch 50/469, Loss: 3.0001
Epoch 28/30, Batch 100/469, Loss: 2.9583
Epoch 28/30, Batch 150/469, Loss: 3.0459
Epoch 28/30, Batch 200/469, Loss: 3.0115
Epoch 28/30, Batch 250/469, Loss: 2.9163
Epoch 28/30, Batch 300/469, Loss: 3.0544
Epoch 28/30, Batch 350/469, Loss: 2.9707
Epoch 28/30, Batch 400/469, Loss: 2.9469
Epoch 28/30, Batch 450/469, Loss: 3.0173
Epoch 28/30 completed. Test Loss: 2.9519
Epoch 29/30, Batch 0/469, Loss: 3.0644
Epoch 29/30, Batch 50/469, Loss: 2.9810
Epoch 29/30, Batch 100/469, Loss: 2.9787
Epoch 29/30, Batch 150/469, Loss: 2.9147
Epoch 29/30, Batch 200/469, Loss: 3.0088
Epoch 29/30, Batch 250/469, Loss: 3.0743
Epoch 29/30, Batch 300/469, Loss: 3.0240
Epoch 29/30, Batch 350/469, Loss: 3.0339
Epoch 29/30, Batch 400/469, Loss: 3.0308
Epoch 29/30, Batch 450/469, Loss: 2.9863
Epoch 29/30 completed. Test Loss: 2.9516
Epoch 30/30, Batch 0/469, Loss: 3.0039
Epoch 30/30, Batch 50/469, Loss: 2.9561
Epoch 30/30, Batch 100/469, Loss: 3.0510
Epoch 30/30, Batch 150/469, Loss: 2.9257
Epoch 30/30, Batch 200/469, Loss: 2.9440
Epoch 30/30, Batch 250/469, Loss: 2.9676
Epoch 30/30, Batch 300/469, Loss: 3.1028
Epoch 30/30, Batch 350/469, Loss: 2.9682
Epoch 30/30, Batch 400/469, Loss: 2.9149
Epoch 30/30, Batch 450/469, Loss: 2.9851
Epoch 30/30 completed. Test Loss: 2.9514
samples shape:  (100, 7, 7, 1)
samples shape:  (100, 7, 7)
Final Test Loss: 2.9514
No description has been provided for this image
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.23252869..1.2633749].
samples shape:  (100, 28, 28, 3)
No description has been provided for this image

Question 5: Causal Transformer: Text¶

Now lets consider text! You are probably already fimilar with autoregressive transformers for text, now more commonly known as Large Language Modesl (LLMs). We will now implement a simplified version.

We will be detailing with a small poetry dataset. See some of the data below.

In [2]:
data = visualize_q5_data()
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
Sample 1
E.E. Cummings, [as freedom is a breakfastfood] from Complete Poems 1904-1962, edited by George J. Firmage. Copyright 1926, 1954, 1991 by the Trustees for the E.E. Cummings Trust. Copyright  1985 by George James Firmage. Reprinted with the permission of Liveright Publishing Corporation.
--------------------------------------------------------------------------------

Sample 2
The moon has left the sky, love,
The stars are hiding now,
And frowning on the world, love,
Night bares her sable brow.

The snow is on the ground, love,
And cold and keen the air is.
Im singing here to you, love;
Youre dreaming there in Paris.

But this is Natures law, love,
Though just it may not seem,
That men should wake to sing, love;
While maidens sleep and dream.

Them care may not molest, love,
Nor stir them from their slumbers,
Though midnight find the swain, love.
Still halting oer his numbers.

I watch the rosy dawn, love,
Come stealing up the east,
While all things round rejoice, love,
That Night her reign has ceased.

The lark will soon be heard, love,
And on his way be winging;
When Natures poets, wake, love,
Why should a man be singing?
--------------------------------------------------------------------------------

Sample 3
My sweetest Lesbia, let us live and love,
And though the sager sort our deeds reprove,
Let us not weigh them. Heavens great lamps do dive
Into their west, and straight again revive,
But soon as once set is our little light,
Then must we sleep one ever-during night.

If all would lead their lives in love like me,
Then bloody swords and armor should not be;
No drum nor trumpet peaceful sleeps should move,
Unless alarm came from the camp of love.
But fools do live, and waste their little light,
And seek with pain their ever-during night.

When timely death my life and fortune ends,
Let not my hearse be vexed with mourning friends,
But let all lovers, rich in triumph, come
And with sweet pastimes grace my happy tomb;
And Lesbia, close up thou my little light,
And crown with love my ever-during night.
--------------------------------------------------------------------------------

Sample 4
When, in disgrace with fortune and mens eyes,
I all alone beweep my outcast state,
And trouble deaf heaven with my bootless cries,
And look upon myself and curse my fate,
Wishing me like to one more rich in hope,
Featured like him, like him with friends possessed,
Desiring this mans art and that mans scope,
With what I most enjoy contented least;
Yet in these thoughts myself almost despising,
Haply I think on thee, and then my state,
(Like to the lark at break of day arising
From sullen earth) sings hymns at heavens gate;
       For thy sweet love remembered such wealth brings
       That then I scorn to change my state with kings.
--------------------------------------------------------------------------------

Part (a) Modeling Text¶

Train a transformer on the poetry dataset.

Data Preprocessing:

  • We will use a simple method to tokenize the data. We will convert each unique character into a token. (Current LLMs use more sophisticated tokenizers, most commonly, byte-pair encoding)
  • Previously we have leveraged a <bos> as part of the model, just like iGPT. For text, we may not always sample a sequence that starts at the beginning. Instead, we will add the <bos> token to the beginning of every sequence in the dataset, and remove the <bos> token from the model.
  • Another problem is that the model must know when to stop sampling. This is done by appending an <eos>, or end of sequence token at the end of every sequence in the dataset.
  • We can now convert the sequence into subsequences of size context_length, for training!

We recommend the following hyperparameters:

  • Sequence length: 128
  • 5 epochs

You will provide these deliverables

  1. Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and test data (for your entire test set). Code is provided that automatically plots the training curves.
  2. Report the final test set performance of your final model
  3. Provide 5 unconditional samples of 128 characters showcasing the model text generation capabilities (text samples should stop after <eos>. Text after <eos> can be removed in post processing)
In [15]:
import torch
from torch.utils.data import Dataset

class Tokenizer:
    def __init__(self, texts):
        # Create a set of all unique characters across all texts
        all_chars = set()
        for text in texts:
            all_chars.update(text)
        
        # Sort characters for consistent mapping
        all_chars = sorted(all_chars)
        
        self.char_to_id = {char: i + 2 for i, char in enumerate(all_chars)}
        self.id_to_char = {i + 2: char for i, char in enumerate(all_chars)}
        
        self.bos_token = 0
        self.eos_token = 1
        
        self.char_to_id['<bos>'] = self.bos_token
        self.char_to_id['<eos>'] = self.eos_token
        self.id_to_char[self.bos_token] = '<bos>'
        self.id_to_char[self.eos_token] = '<eos>'
        
        self.vocab_size = len(self.char_to_id)
    
    def encode(self, text):
        tokens = [self.char_to_id[char] for char in text]
        tokens.insert(0, self.bos_token)
        tokens.append(self.eos_token)
        return torch.tensor(tokens)
    
    def decode(self, tokens):
        chars = [self.id_to_char[token] for token in tokens if token != self.bos_token and token != self.eos_token]
        # remove the special tokens
        chars = [char for char in chars if char != '<bos>' and char != '<eos>']
        return ''.join(chars)

class TextData(Dataset): 
    def __init__(self, texts, tokenizer, sequence_length):
        self.tokenizer = tokenizer
        self.sequence_length = sequence_length
        
        # Tokenize all texts with BOS and EOS tokens
        self.sequences = []
        for text in texts:
            # Encode the text (this adds BOS and EOS)
            tokens = tokenizer.encode(text)
            stride = 1
            if len(tokens) > sequence_length:
                for i in range(0, len(tokens) - sequence_length + 1, stride):
                    self.sequences.append(tokens[i:i + sequence_length])
            # else:
            #     # Drop the shorter sequences
            #     padded = torch.full((sequence_length,), self.tokenizer.eos_token, dtype=tokens.dtype)
            #     padded[:len(tokens)] = tokens
            #     self.sequences.append(padded)

    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, index):
        return self.sequences[index]
    
from torch.utils.data import Dataset, DataLoader

# Then in your function:
def create_text_data_loader(texts, sequence_length, tokenizer, batch_size):
    dataset = TextData(texts, tokenizer, sequence_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)
In [19]:
# text model architecture version

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """
        q: (batch_size, n_heads, seq_len, head_size)
        k: (batch_size, n_heads, seq_len, head_size)
        v: (batch_size, n_heads, seq_len, head_size)
        """
        d_k = q.shape[-1]
        scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(d_k)  # (batch_size, n_heads, seq_len, seq_len)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attention_weights = torch.softmax(scores, dim=-1)  # (batch_size, n_heads, seq_len, seq_len)
        attention_weights = self.dropout(attention_weights)  # (batch_size, n_heads, seq_len, seq_len)

        output = torch.matmul(attention_weights, v)  # (batch_size, n_heads, seq_len, head_size)
        return output

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.0, cache=False):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_size = d_model // n_heads
        self.use_cache = cache

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.attention = ScaledDotProductAttention(dropout=dropout)
        self.cached_k = None
        self.cached_v = None

    def split_heads(self, x):
        """
        x: (batch_size, seq_len, d_model)
        """
        batch_size, seq_len, d_model = x.shape
        return x.view(batch_size, seq_len, self.n_heads, self.head_size).transpose(1, 2)  # (batch_size, n_heads, seq_len, head_size)
    
    def combine_heads(self, x):
        """
        x: (batch_size, n_heads, seq_len, head_size)
        """
        batch_size, n_heads, seq_len, head_size = x.shape
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)  # (batch_size, seq_len, d_model)
    
    def forward(self, x, mask=None, use_cache=False, past_key_values=None):
        batch_size, seq_len, d_model = x.shape
        if past_key_values is not None:
            self.cached_k, self.cached_v = past_key_values
        q = self.W_q(x)  # (batch_size, seq_len, d_model)
        k = self.W_k(x)
        v = self.W_v(x)

        q = self.split_heads(q)  # (batch_size, n_heads, seq_len, head_size)
        k = self.split_heads(k)
        v = self.split_heads(v)
        
        # Use KV cache if enabled
        if use_cache and self.cached_k is not None and self.cached_v is not None:
            # Concatenate current k, v with cached k, v
            k = torch.cat([self.cached_k, k], dim=2)
            v = torch.cat([self.cached_v, v], dim=2)


            self.cached_k = k
            self.cached_v = v
        
        # Create causal mask if needed
        if mask is None:
            # If using cache, adjust mask to account for the full sequence length
            full_seq_len = k.size(2)
            # For cached version, we need to adjust the mask to allow attention to all past tokens
            if use_cache and self.cached_k is not None:
                # Create a mask where current tokens can attend to all previous tokens
                # Current sequence position is at seq_len
                seq_position = seq_len
                # Create a mask that allows each token to see itself and all previous tokens
                mask = torch.ones(seq_len, full_seq_len).to(x.device)
                # Make it causal by setting future positions to 0
                mask[:, seq_position:] = 0
            else:
                # Standard causal mask for the full sequence
                mask = torch.tril(torch.ones(full_seq_len, full_seq_len)).to(x.device)
            
            mask = mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)

        # Use the attention module directly
        output = self.attention(q, k, v, mask)  # (batch_size, n_heads, seq_len, head_size)
        
        # Combine heads
        output = self.combine_heads(output)  # (batch_size, seq_len, d_model)
        past_key_values = (k, v)
        if use_cache:
            return self.dropout(self.out(output)) , past_key_values
        else:
            return self.dropout(self.out(output))
    
    def clear_cache(self):
        self.cached_k = None
        self.cached_v = None
    
class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1, use_cache=False):
        super().__init__()
        self.masked_mha = MultiHeadAttention(d_model, n_heads, dropout, cache=use_cache)
        self.layer_norm1 = nn.LayerNorm(d_model)
        
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),  
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout)
        )
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, use_cache=False, past_key_values=None):

        # Self-attention with residual connection and layer normalization
        residual = x
        x = self.layer_norm1(x)  # Pre-norm architecture
        if use_cache and past_key_values is not None:
            x, past_key_values = self.masked_mha(x, use_cache=use_cache, past_key_values=past_key_values)
        else:
            x = self.masked_mha(x)
            
        x = residual + x  # Residual connection
        
        # Feed forward with residual connection and layer normalization
        residual = x
        x = self.layer_norm2(x)  # Pre-norm architecture
        x = self.feed_forward(x)
        x = residual + x  # Residual connection
        if use_cache:
            return x , past_key_values
        else:
            return x
    
    def clear_cache(self):
        self.masked_mha.clear_cache()

class iGPT(nn.Module):
    def __init__(self, vocab_size, context_length, d_model, n_heads, n_layers, dropout=0.1, use_cache=False):
        super().__init__()
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.d_model = d_model
        self.n_heads = n_heads  
        self.n_layers = n_layers
        self.dropout = dropout
        self.use_cache = use_cache
        
        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        
        # Positional embedding (learned, as per iGPT specs)
        self.position_embedding = nn.Embedding(context_length, d_model)
        
        # Dropout
        self.dropout = nn.Dropout(dropout)
        
        # Stack of decoder layers
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, dropout, use_cache=use_cache) 
            for _ in range(n_layers)
        ])
        
        # Final layer norm
        self.layer_norm = nn.LayerNorm(d_model)
        
        # Output projection
        self.output_projection = nn.Linear(d_model, vocab_size)

    def forward(self, x, past_key_values=None, use_cache=False):
        # x shape: (batch_size, seq_len)
        batch_size, seq_len = x.shape
        device = x.device
        
        # Create position indices
        positions = torch.arange(0, seq_len, dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, -1)
        
        # Get embeddings
        token_emb = self.token_embedding(x)  # (batch_size, seq_len, d_model)
        pos_emb = self.position_embedding(positions)  # (batch_size, seq_len, d_model)
        
        # Combine embeddings
        x = token_emb + pos_emb  # (batch_size, seq_len, d_model)
        x = self.dropout(x)
        
        # Apply decoder layers
        past_key_values = None
        for layer in self.decoder_layers:
            if use_cache:
                x, past_key_values = layer(x, use_cache=use_cache, past_key_values=past_key_values)
            else:
                x = layer(x)
        
        # Apply final layer norm
        x = self.layer_norm(x)  # (batch_size, seq_len, d_model)
        
        # Project to vocabulary
        logits = self.output_projection(x)  # (batch_size, seq_len, vocab_size)
        if use_cache:
            return logits, past_key_values
        else:
            return logits
    
    def clear_cache(self):
        for layer in self.decoder_layers:
            layer.clear_cache()
In [20]:
import math

def create_dataset(data, image_shape, batch_size):
    """
    Converts image data to token sequences and creates PyTorch DataLoader.
    
    Args:
        data: A (n_samples, H, W, C) uint8 numpy array of images
        image_shape: (H, W, C) tuple specifying image dimensions
        batch_size: Batch size for DataLoader
        
    Returns:
        DataLoader object with tokenized image sequences
    """
    H, W, C = image_shape
    
    # Convert RGB pixels to single tokens (4 values per channel = 64 possible values)
    # Shape: (n_samples, H, W, C) -> (n_samples, H, W)
    if C == 3:
        # Convert RGB values to a single token: r*16 + g*4 + b
        # Each channel has values in {0,1,2,3}, so we can encode as a single number 0-63
        data_tokens = (data[:,:,:,0] * 16 + data[:,:,:,1] * 4 + data[:,:,:,2])
    else:
        # For grayscale, just use the values directly
        data_tokens = data.reshape(-1, H, W)
    
    # Flatten spatial dimensions to create sequences
    # Shape: (n_samples, H, W) -> (n_samples, H*W)
    data_flat = data_tokens.reshape(-1, H * W)
    
    # Convert to PyTorch tensors
    dataset = torch.utils.data.TensorDataset(torch.tensor(data_flat, dtype=torch.long))
    
    # Create data loader
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

def evaluate_model(model, data_loader, sequence_length, vocab_size, device):
    """
    Evaluates model performance on a dataset.
    
    Args:
        model: The iGPT model
        data_loader: DataLoader containing tokenized images
        sequence_length: Length of token sequences including <bos>
        vocab_size: Size of vocabulary
        device: Device to run evaluation on
        
    Returns:
        Average loss (negative log-likelihood) per dimension
    """
    model.eval()
    total_loss = 0
    total_samples = 0
    
    with torch.no_grad():
        for data in data_loader:
            data = data.to(device)  # Shape: (batch_size, sequence_length-1)
            batch_size = data.size(0)
            
            input_seq = data[:, :-1]
            targets = data[:, 1:]
            
            # Forward pass
            logits = model(input_seq)  # Remove last position's prediction
            
            # Compute loss
            loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1), reduction='sum')
            
            total_loss += loss.item()
            total_samples += batch_size * (sequence_length - 1)
    
    return total_loss / total_samples



def train_igpt(model, train_loader, test_loader, sequence_length, vocab_size, 
               device, num_epochs, learning_rate):
    """
    Trains the iGPT model.
    
    Args:
        model: The iGPT model to train
        train_loader: DataLoader for training data
        test_loader: DataLoader for test data
        sequence_length: Length of token sequences including <bos>
        vocab_size: Size of vocabulary
        device: Device to train on
        num_epochs: Number of training epochs
        learning_rate: Initial learning rate
        
    Returns:
        train_losses: Array of training losses per minibatch
        test_losses: Array of test losses per epoch
    """
    # Initialize optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    # Learning rate scheduler with warmup and cosine decay
    warmup_steps = 100
    total_steps = len(train_loader) * num_epochs
    
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        else:
            decay_ratio = (step - warmup_steps) / (total_steps - warmup_steps)
            return 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    
    # Initialize arrays to store losses
    train_losses = []
    test_losses = [evaluate_model(model, test_loader, sequence_length, vocab_size, device)]
    
    # Training loop
    for epoch in range(num_epochs):
        model.train()
        epoch_losses = []
        batch_idx = 0
        for data in train_loader:
            batch_idx += 1
            data = data.to(device)  # Shape: (batch_size, sequence_length)
            
            # Shape: (batch_size, sequence_length-1)
            input_seq = data[:, :-1]
            targets = data[:, 1:]
            
            # Forward pass
            logits = model(input_seq)
            targets = targets.to(device)
            
            loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
            # print(f"loss: {loss.item():.4f}")
            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            # Record loss
            train_losses.append(loss.item())
            
            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")
        
        # Evaluate on test set after each epoch
        test_loss = evaluate_model(model, test_loader, sequence_length, vocab_size, device)
        test_losses.append(test_loss)
        print(f"Epoch {epoch+1}/{num_epochs} completed. Test Loss: {test_loss:.4f}")
    return np.array(train_losses), np.array(test_losses)
In [21]:
def generate_text_samples(model, tokenizer, max_length, device, num_samples=10, temperature=1.0, use_cache=False):
    """
    Generates text samples from the trained model.
    
    Args:
        model: The trained language model
        tokenizer: The tokenizer used to encode/decode text
        max_length: Maximum length of the generated sequence (including BOS/EOS)
        device: Device to run generation on
        num_samples: Number of samples to generate
        temperature: Controls randomness (lower = more deterministic)
        use_cache: Whether to use caching for faster sampling
        
    Returns:
        List of generated text samples and a list of generation times
    """
    model.eval()
    samples = []
    import time
    time_list = []
    
    with torch.no_grad():
        for _ in range(num_samples):
            start_time = time.time()
            
            # Start with just the BOS token
            current_seq = torch.tensor([[tokenizer.bos_token]], dtype=torch.long, device=device)
            
            # Cache for key-value pairs if using caching
            past_key_values = None
            
            # Autoregressive generation - one token at a time
            for _ in range(max_length - 1):  # -1 because we already have BOS
                if use_cache and past_key_values is not None:
                    # Only need to process the new token with cached key-values
                    logits, past_key_values = model(
                        current_seq[:, -1:],
                        past_key_values=past_key_values, 
                        use_cache=True
                    )
                    logits = logits[:, -1, :]  # Get prediction for current position
                else:
                    # Process the entire sequence
                    if use_cache:
                        logits, past_key_values = model(current_seq, use_cache=True)
                        logits = logits[:, -1, :]  # Get prediction for current position
                    else:
                        logits = model(current_seq)
                        logits = logits[:, -1, :]  # Get prediction for last position
                
                # Apply temperature
                if temperature != 1.0:
                    logits = logits / temperature
                
                # Sample from the probability distribution
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, 1)
                
                # Append new token to sequence
                current_seq = torch.cat([current_seq, next_token], dim=1)
                
                # Stop if EOS token is generated
                if next_token.item() == tokenizer.eos_token:
                    break
            
            # Decode the generated sequence
            generated_tokens = current_seq[0].cpu().tolist()
            generated_text = tokenizer.decode(generated_tokens)
            samples.append(generated_text)
            
            end_time = time.time()
            time_list.append(end_time - start_time)
    
    return samples, np.array(time_list)
In [22]:
import torch.utils.data as data
def q5_a(train_text, test_text):
  """
  train_text: list[str] Train text sequences.
  test_text: list[str] Test text sequences.

  Returns
  - a (# of training iterations,) numpy array of train_losses evaluated every minibatch
  - a (# of epochs + 1,) numpy array of test_losses evaluated once at initialization and after each epoch
  - a list of 5 (str), 5 generated samples from the model.
  """
  sequence_length = 128
  epochs = 5
  learning_rate = 1e-3
  d_model = 128  
  n_heads = 4    
  n_layers = 6   
  batch_size = 1024
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  
  tokenizer = Tokenizer(train_text)
  vocab_size = tokenizer.vocab_size

  train_loader = create_text_data_loader(train_text, sequence_length, tokenizer, batch_size)
  test_loader = create_text_data_loader(test_text, sequence_length, tokenizer, batch_size)
  
  
  model = iGPT(vocab_size, sequence_length, d_model, n_heads, n_layers).to(device)
  train_losses, test_losses = train_igpt(model, train_loader, test_loader, sequence_length, vocab_size, device, epochs, learning_rate)
  
  text_samples, _ = generate_text_samples(model, tokenizer, sequence_length, device, num_samples=5)
  return train_losses, test_losses, text_samples

Results¶

Once you've implemented q5_a, execute the cells below to visualize and save your results

In [23]:
q5a_save_results(q5_a)
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
Epoch 1/5, Batch 50/477, Loss: 2.8145
Epoch 1/5, Batch 100/477, Loss: 2.4764
Epoch 1/5, Batch 150/477, Loss: 2.3922
Epoch 1/5, Batch 200/477, Loss: 2.3275
Epoch 1/5, Batch 250/477, Loss: 2.2068
Epoch 1/5, Batch 300/477, Loss: 2.0961
Epoch 1/5, Batch 350/477, Loss: 2.0128
Epoch 1/5, Batch 400/477, Loss: 1.9417
Epoch 1/5, Batch 450/477, Loss: 1.8972
Epoch 1/5 completed. Test Loss: 1.8999
Epoch 2/5, Batch 50/477, Loss: 1.8257
Epoch 2/5, Batch 100/477, Loss: 1.7714
Epoch 2/5, Batch 150/477, Loss: 1.7343
Epoch 2/5, Batch 200/477, Loss: 1.7065
Epoch 2/5, Batch 250/477, Loss: 1.6745
Epoch 2/5, Batch 300/477, Loss: 1.6477
Epoch 2/5, Batch 350/477, Loss: 1.6189
Epoch 2/5, Batch 400/477, Loss: 1.5987
Epoch 2/5, Batch 450/477, Loss: 1.5895
Epoch 2/5 completed. Test Loss: 1.6667
Epoch 3/5, Batch 50/477, Loss: 1.5469
Epoch 3/5, Batch 100/477, Loss: 1.5355
Epoch 3/5, Batch 150/477, Loss: 1.5236
Epoch 3/5, Batch 200/477, Loss: 1.5015
Epoch 3/5, Batch 250/477, Loss: 1.5034
Epoch 3/5, Batch 300/477, Loss: 1.4954
Epoch 3/5, Batch 350/477, Loss: 1.4756
Epoch 3/5, Batch 400/477, Loss: 1.4742
Epoch 3/5, Batch 450/477, Loss: 1.4636
Epoch 3/5 completed. Test Loss: 1.6080
Epoch 4/5, Batch 50/477, Loss: 1.4474
Epoch 4/5, Batch 100/477, Loss: 1.4380
Epoch 4/5, Batch 150/477, Loss: 1.4350
Epoch 4/5, Batch 200/477, Loss: 1.4301
Epoch 4/5, Batch 250/477, Loss: 1.4225
Epoch 4/5, Batch 300/477, Loss: 1.4228
Epoch 4/5, Batch 350/477, Loss: 1.4133
Epoch 4/5, Batch 400/477, Loss: 1.4063
Epoch 4/5, Batch 450/477, Loss: 1.3998
Epoch 4/5 completed. Test Loss: 1.6020
Epoch 5/5, Batch 50/477, Loss: 1.3968
Epoch 5/5, Batch 100/477, Loss: 1.4050
Epoch 5/5, Batch 150/477, Loss: 1.3925
Epoch 5/5, Batch 200/477, Loss: 1.3913
Epoch 5/5, Batch 250/477, Loss: 1.4003
Epoch 5/5, Batch 300/477, Loss: 1.3889
Epoch 5/5, Batch 350/477, Loss: 1.3907
Epoch 5/5, Batch 400/477, Loss: 1.3959
Epoch 5/5, Batch 450/477, Loss: 1.4015
Epoch 5/5 completed. Test Loss: 1.6017
Final Test Loss: 1.6017
No description has been provided for this image
Sample 1
Doe gods "Hen the Spype and "yde with eyes
Not little kins hise outbing through the bair,
And with your bronze younge tongue 

Sample 2
Those dress which forth did golden she
Of heaving in white do:

There she's sprungth,
And thee but them burn.
Sleeps on it

Sample 3
Nymphings love's the long Time, and leaves store;
where your stouth with the floor story-flocks
           To still was forth

Sample 4

He heaven thou be small at like a certain,
With the risely plant that wait
With they death may the Chokall straight-19Kobe,

Sample 5
A from Poetrynel's beast
And so song: from grown and his grace,
Free, which making the eyes content,
Leaves of right, what b

No description has been provided for this image

Question 6: Causal Transformer: Multimodal¶

So far, we have been dealing only with autoregressive generation of a single modality. Now we will train a model that operates on multiple modalities!

We will use the text labeled colored MNIST dataset, which has a text description of the MNIST image. Run the cell below to visualize the data along with the text annotation. This is the Colored MNIST v2 dataset, which also comes with these text labels.

In [29]:
visualize_q6_data()
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
No description has been provided for this image

Part (a) Multimodal Text and Image Generation¶

Implement and train an autoregressive (AR) model capable of handling both text and image data. The model should be designed to process sequences composed of concatenated text and image tokens in both orders (text followed by images and images followed by text). Additionally, the model should be capable of generating unconditional text and image samples.

Data Preprocessing:

  • Text Tokens: Map each unique word in the text data to a unique token. (Note that all text descriptions contain the exact same amount of words. This simplifies text processing, as you won't have to deal with sequences of different lengths as in Question 5)
  • Image Tokens: Quantize the image data into tokens using the VQVAE tokenizer from Problem 4.
  • In this problem, we have 2 modalities. Introduce an <end of text> token and an <end of image> token. After seeing such a token, the model should switch to sampling the next modality.
  • Formulate batches as sequences of concat([<end of image>, text_tokens, <end of text>, image_tokens]) and concat([<end of text>, image_tokens, <end of image>, text_tokens]). With a 50/50 split between each ordering.

Inference:

  • During inference, we cannot mix modality tokens. During sampling we can restrict the logits to only be within the relevant modality.
  • After <end of image>, only allow the model to sample text tokens (including <end of text>)
  • After <end of text>, only allow the model to sample image tokens (including <end of image>)
  • At the very start (conditioned on the <bos> token, only allow the model to sample one of (<end of image> or <end of text>))
  • As the model may not always correctly sample the <end of image> token before the image ends, you may add a rule to force the model to always sample the correct number of image tokens (49 tokens).

You can use the same hyperparameters as in 4(b) (but of course, feel free to tune your model to achieve better performance)

You will provide these deliverables

  1. Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and test data (for your entire test set). Code is provided that automatically plots the training curves.
  2. Report the final test set performance of your final model
  3. 9 conditional samples based on provided text.
  4. 9 conditional samples based on provided images.
  5. 9 unconditional samples showcasing the model's capability in generating standalone text and images.
In [30]:
# multiple model architecture version

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """
        q: (batch_size, n_heads, seq_len, head_size)
        k: (batch_size, n_heads, seq_len, head_size)
        v: (batch_size, n_heads, seq_len, head_size)
        """
        d_k = q.shape[-1]
        scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(d_k)  # (batch_size, n_heads, seq_len, seq_len)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attention_weights = torch.softmax(scores, dim=-1)  # (batch_size, n_heads, seq_len, seq_len)
        attention_weights = self.dropout(attention_weights)  # (batch_size, n_heads, seq_len, seq_len)

        output = torch.matmul(attention_weights, v)  # (batch_size, n_heads, seq_len, head_size)
        return output

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.0, cache=False):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_size = d_model // n_heads
        self.use_cache = cache

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.attention = ScaledDotProductAttention(dropout=dropout)
        self.cached_k = None
        self.cached_v = None

    def split_heads(self, x):
        """
        x: (batch_size, seq_len, d_model)
        """
        batch_size, seq_len, d_model = x.shape
        return x.view(batch_size, seq_len, self.n_heads, self.head_size).transpose(1, 2)  # (batch_size, n_heads, seq_len, head_size)
    
    def combine_heads(self, x):
        """
        x: (batch_size, n_heads, seq_len, head_size)
        """
        batch_size, n_heads, seq_len, head_size = x.shape
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)  # (batch_size, seq_len, d_model)
    
    def forward(self, x, mask=None, use_cache=False, past_key_values=None):
        batch_size, seq_len, d_model = x.shape
        if past_key_values is not None:
            self.cached_k, self.cached_v = past_key_values
        q = self.W_q(x)  # (batch_size, seq_len, d_model)
        k = self.W_k(x)
        v = self.W_v(x)

        q = self.split_heads(q)  # (batch_size, n_heads, seq_len, head_size)
        k = self.split_heads(k)
        v = self.split_heads(v)
        
        # Use KV cache if enabled
        if use_cache and self.cached_k is not None and self.cached_v is not None:
            # Concatenate current k, v with cached k, v
            k = torch.cat([self.cached_k, k], dim=2)
            v = torch.cat([self.cached_v, v], dim=2)


            self.cached_k = k
            self.cached_v = v
        
        # Create causal mask if needed
        if mask is None:
            # If using cache, adjust mask to account for the full sequence length
            full_seq_len = k.size(2)
            # For cached version, we need to adjust the mask to allow attention to all past tokens
            if use_cache and self.cached_k is not None:
                # Create a mask where current tokens can attend to all previous tokens
                # Current sequence position is at seq_len
                seq_position = seq_len
                # Create a mask that allows each token to see itself and all previous tokens
                mask = torch.ones(seq_len, full_seq_len).to(x.device)
                # Make it causal by setting future positions to 0
                mask[:, seq_position:] = 0
            else:
                # Standard causal mask for the full sequence
                mask = torch.tril(torch.ones(full_seq_len, full_seq_len)).to(x.device)
            
            mask = mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)

        # Use the attention module directly
        output = self.attention(q, k, v, mask)  # (batch_size, n_heads, seq_len, head_size)
        
        # Combine heads
        output = self.combine_heads(output)  # (batch_size, seq_len, d_model)
        past_key_values = (k, v)
        if use_cache:
            return self.dropout(self.out(output)) , past_key_values
        else:
            return self.dropout(self.out(output))
    
    def clear_cache(self):
        self.cached_k = None
        self.cached_v = None
    
class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1, use_cache=False):
        super().__init__()
        self.masked_mha = MultiHeadAttention(d_model, n_heads, dropout, cache=use_cache)
        self.layer_norm1 = nn.LayerNorm(d_model)
        
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),  
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout)
        )
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, use_cache=False, past_key_values=None):

        # Self-attention with residual connection and layer normalization
        residual = x
        x = self.layer_norm1(x)  # Pre-norm architecture
        if use_cache and past_key_values is not None:
            x, past_key_values = self.masked_mha(x, use_cache=use_cache, past_key_values=past_key_values)
        else:
            x = self.masked_mha(x)
            
        x = residual + x  # Residual connection
        
        # Feed forward with residual connection and layer normalization
        residual = x
        x = self.layer_norm2(x)  # Pre-norm architecture
        x = self.feed_forward(x)
        x = residual + x  # Residual connection
        if use_cache:
            return x , past_key_values
        else:
            return x
    
    def clear_cache(self):
        self.masked_mha.clear_cache()

class iGPT(nn.Module):
    def __init__(self, vocab_size, context_length, d_model, n_heads, n_layers, dropout=0.1, use_cache=False):
        super().__init__()
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.d_model = d_model
        self.n_heads = n_heads  
        self.n_layers = n_layers
        self.dropout = dropout
        self.use_cache = use_cache
        
        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        
        # Positional embedding (learned, as per iGPT specs)
        self.position_embedding = nn.Embedding(context_length, d_model)
        
        # Dropout
        self.dropout = nn.Dropout(dropout)
        
        # Stack of decoder layers
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, dropout, use_cache=use_cache) 
            for _ in range(n_layers)
        ])
        
        # Final layer norm
        self.layer_norm = nn.LayerNorm(d_model)
        
        # Output projection
        self.output_projection = nn.Linear(d_model, vocab_size)

    def forward(self, x, past_key_values=None, use_cache=False):
        # x shape: (batch_size, seq_len)
        batch_size, seq_len = x.shape
        device = x.device
        
        # Create position indices
        positions = torch.arange(0, seq_len, dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, -1)
        
        # Get embeddings
        token_emb = self.token_embedding(x)  # (batch_size, seq_len, d_model)
        pos_emb = self.position_embedding(positions)  # (batch_size, seq_len, d_model)
        
        # Combine embeddings
        x = token_emb + pos_emb  # (batch_size, seq_len, d_model)
        x = self.dropout(x)
        
        # Apply decoder layers
        past_key_values = None
        for layer in self.decoder_layers:
            if use_cache:
                x, past_key_values = layer(x, use_cache=use_cache, past_key_values=past_key_values)
            else:
                x = layer(x)
        
        # Apply final layer norm
        x = self.layer_norm(x)  # (batch_size, seq_len, d_model)
        
        # Project to vocabulary
        logits = self.output_projection(x)  # (batch_size, seq_len, vocab_size)
        if use_cache:
            return logits, past_key_values
        else:
            return logits
    
    def clear_cache(self):
        for layer in self.decoder_layers:
            layer.clear_cache()
In [31]:
import math


def evaluate_model(model, data_loader, sequence_length, vocab_size, device):
    """
    Evaluates model performance on a dataset.
    
    Args:
        model: The iGPT model
        data_loader: DataLoader containing tokenized sequences (already includes BOS token)
        sequence_length: Length of token sequences including <bos>
        vocab_size: Size of vocabulary
        device: Device to run evaluation on
        
    Returns:
        Average loss (negative log-likelihood) per dimension
    """
    model.eval()
    total_loss = 0
    total_samples = 0
    
    with torch.no_grad():
        for data in data_loader:
            data = data.to(device)  # Shape: (batch_size, sequence_length)
            batch_size = data.size(0)
            
            # Data already includes BOS token at the beginning
            # Create input sequence (all tokens except the last one)
            input_seq = data[:, :-1]  # Shape: (batch_size, sequence_length-1)
            
            # Create targets (all tokens except the first BOS token)
            targets = data[:, 1:]  # Shape: (batch_size, sequence_length-1)
            
            # Forward pass
            logits = model(input_seq)  # Shape: (batch_size, sequence_length-1, vocab_size)
            
            # Compute loss
            loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1), reduction='sum')
            
            total_loss += loss.item()
            total_samples += batch_size * (sequence_length - 1)
    
    return total_loss / total_samples



def train_igpt(model, train_loader, test_loader, sequence_length, vocab_size, 
               device, num_epochs, learning_rate):
    """
    Trains the iGPT model.
    
    Args:
        model: The iGPT model to train
        train_loader: DataLoader for training data (already includes BOS token)
        test_loader: DataLoader for test data (already includes BOS token)
        sequence_length: Length of token sequences including <bos>
        vocab_size: Size of vocabulary
        device: Device to train on
        num_epochs: Number of training epochs
        learning_rate: Initial learning rate
        
    Returns:
        train_losses: Array of training losses per minibatch
        test_losses: Array of test losses per epoch
    """
    # Initialize optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    # Learning rate scheduler with warmup and cosine decay
    warmup_steps = 1000
    total_steps = len(train_loader) * num_epochs
    
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        else:
            decay_ratio = (step - warmup_steps) / (total_steps - warmup_steps)
            return 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    
    # Initialize arrays to store losses
    train_losses = []
    test_losses = [evaluate_model(model, test_loader, sequence_length, vocab_size, device)]
    
    # Training loop
    for epoch in range(num_epochs):
        model.train()
        epoch_losses = []
        
        for (batch_idx,data) in enumerate(train_loader):
            data = data.to(device)  # Shape: (batch_size, sequence_length)
            batch_size = data.size(0)
            
            # Data already includes BOS token at the beginning
            # Create input sequence (all tokens except the last one)
            input_seq = data[:, :-1]  # Shape: (batch_size, sequence_length-1)
            
            # Create targets (all tokens except the first BOS token)
            targets = data[:, 1:]  # Shape: (batch_size, sequence_length-1)
            
            # Forward pass
            logits = model(input_seq)  # Shape: (batch_size, sequence_length-1, vocab_size)
            
            # Compute loss
            loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
            
            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            # Record loss
            train_losses.append(loss.item())
            epoch_losses.append(loss.item())
            
            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")
        
        # Evaluate on test set after each epoch
        test_loss = evaluate_model(model, test_loader, sequence_length, vocab_size, device)
        test_losses.append(test_loss)
        print(f"Epoch {epoch+1}/{num_epochs} completed. Test Loss: {test_loss:.4f}")
    
    return np.array(train_losses), np.array(test_losses)
In [32]:
class Tokenizer:
    def __init__(self, texts, offset):
        self.texts = texts
        self.offset = offset
        self.all_words = set()
        for text in texts:
            self.all_words.update(text.split())
        
        # Convert set to list for consistent ordering
        self.all_words = list(self.all_words)
        self.vocab_size = len(self.all_words)

        # Add special tokens after calculating vocab_size
        # Reserve token 0 for BOS token
        self.bos_token = 0
        self.end_of_text_token = self.vocab_size + self.offset
        self.end_of_image_token = self.vocab_size + 1 + self.offset
        self.all_words.extend(['<end_of_text>', '<end_of_image>'])
        
        # Create mappings with offset applied (starting from 1 to reserve 0 for BOS)
        self.word_to_id = {word: i + 1 + self.offset for i, word in enumerate(self.all_words)}
        self.id_to_word = {i + 1 + self.offset: word for i, word in enumerate(self.all_words)}
        # Add BOS token to mappings
        self.id_to_word[self.bos_token] = '<bos>'

        
    def text_encode(self, text):
        tokens = [self.word_to_id[word] for word in text.split()]
        return torch.tensor(tokens)
        
    def text_decode(self, tokens):
        return ' '.join([self.id_to_word[token] for token in tokens if token != self.end_of_text_token and token != self.bos_token])
    
def create_dataset(images, texts, vqvae, text_tokenizer, batch_size):
    # create a dataset of images and texts
    dataset = []
    bos_token = text_tokenizer.bos_token
    end_of_image_token = text_tokenizer.end_of_image_token
    end_of_text_token = text_tokenizer.end_of_text_token
 
    print(f"Creating dataset from {len(images)} samples...")
    
    # Pre-tokenize all text data at once for efficiency
    print("Pre-tokenizing all text data...")
    all_text_tokens = [text_tokenizer.text_encode(text) for text in texts]
    
    # Batch process images for VQVAE quantization
    print("Batch processing images...")
    batch_size_process = 128
    all_image_tokens = []

    for i in range(0, len(images), batch_size_process):
        batch_end = min(i + batch_size_process, len(images))
        batch_images = images[i:batch_end]
        
        # Process batch of images
        batch_image_tokens = vqvae.quantize(batch_images)
        
        # Flatten each image's tokens and store
        for j in range(batch_image_tokens.shape[0]):
            image_tokens_flat = batch_image_tokens[j].flatten()
            all_image_tokens.append(image_tokens_flat)
        
        if i % (batch_size_process * 1000) == 0:
            print(f"Processed {min(i + batch_size_process, len(images))}/{len(images)} images ({min(i + batch_size_process, len(images))/len(images)*1000:.1f}%)")
    
    # Create special token tensors once
    bos_tensor = torch.tensor([bos_token])
    end_of_image_tensor = torch.tensor([end_of_image_token])
    end_of_text_tensor = torch.tensor([end_of_text_token])
    
    print("Assembling dataset...")
    for idx in range(len(texts)):
        text_tokens = all_text_tokens[idx]
        image_tokens_flat = all_image_tokens[idx]
        
        if idx % 2 == 0:
            # text followed by image
            complete_tokens = torch.cat((bos_tensor, end_of_image_tensor, text_tokens, end_of_text_tensor, image_tokens_flat))
            dataset.append(complete_tokens)
        else:
            # image followed by text
            complete_tokens = torch.cat((bos_tensor, end_of_text_tensor, image_tokens_flat, end_of_image_tensor, text_tokens))
            dataset.append(complete_tokens)
    
    print(f"Dataset creation complete! Total samples: {len(dataset)}")
    print(f"Creating DataLoader with batch_size={batch_size}")
    
    # create dataloader
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)
In [33]:
def generate_conditional_samples_from_text(model, text_tokenizer, vqvae, text_prompts, device, max_length=58):
    """
    Generate images conditioned on text prompts.
    
    Args:
        model: Trained iGPT model
        text_tokenizer: Text tokenizer
        vqvae: VQVAE model for decoding image tokens
        text_prompts: List of text strings to condition on
        device: Device to run on
        max_length: Maximum sequence length
        
    Returns:
        List of (image, text) tuples
    """
    model.eval()
    samples = []
    
    with torch.no_grad():
        for text_prompt in text_prompts:
            # Start with BOS token and end_of_image token, then text tokens, then end_of_text token
            text_tokens = text_tokenizer.text_encode(text_prompt)
            input_seq = torch.cat([
                torch.tensor([text_tokenizer.bos_token]),
                torch.tensor([text_tokenizer.end_of_image_token]),
                text_tokens,
                torch.tensor([text_tokenizer.end_of_text_token])
            ]).unsqueeze(0).to(device)
            
            # Generate 49 image tokens
            for _ in range(49):  # 7x7 = 49 image tokens
                logits = model(input_seq)
                next_token_logits = logits[0, -1, :]
                
                # Restrict to image tokens only (0 to vqvae.n_embeddings-1)
                mask = torch.zeros_like(next_token_logits)
                mask[:vqvae.n_embeddings] = 1
                next_token_logits = next_token_logits * mask + (1 - mask) * (-1e9)
                
                # Sample next token
                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, 1)
                
                # Append to sequence
                input_seq = torch.cat([input_seq, next_token.unsqueeze(0)], dim=1)
            
            # Extract image tokens and decode
            image_tokens = input_seq[0, -49:].cpu().numpy().reshape(7, 7)
            decoded_image = vqvae.decode(image_tokens.reshape(1, 7, 7))[0]
            
            samples.append((decoded_image, text_prompt))
    
    return samples

def generate_conditional_samples_from_image(model, text_tokenizer, vqvae, image_prompts, device, max_length=58):
    """
    Generate text conditioned on image prompts.
    
    Args:
        model: Trained iGPT model
        text_tokenizer: Text tokenizer
        vqvae: VQVAE model for encoding image tokens
        image_prompts: Array of images to condition on
        device: Device to run on
        max_length: Maximum sequence length
        
    Returns:
        List of (image, text) tuples
    """
    model.eval()
    samples = []
    
    with torch.no_grad():
        for image_prompt in image_prompts:
            # Quantize the image
            image_tokens = vqvae.quantize(image_prompt.reshape(1, *image_prompt.shape))[0].flatten()
            
            # Start with BOS token, end_of_text token, image tokens, then end_of_image token
            input_seq = torch.cat([
                torch.tensor([text_tokenizer.bos_token]),
                torch.tensor([text_tokenizer.end_of_text_token]),
                torch.tensor(image_tokens),
                torch.tensor([text_tokenizer.end_of_image_token])
            ]).unsqueeze(0).to(device)
            
            # Generate text tokens (typically 6 words based on the dataset)
            generated_text_tokens = []
            for _ in range(6):  # Assuming 6 words per text description
                logits = model(input_seq)
                next_token_logits = logits[0, -1, :]
                
                # Restrict to text tokens only (excluding special tokens)
                mask = torch.zeros_like(next_token_logits)
                # Text tokens start from vqvae.n_embeddings + 1 (excluding BOS which is 0)
                for word, token_id in text_tokenizer.word_to_id.items():
                    if word not in ['<end_of_text>', '<end_of_image>']:
                        mask[token_id] = 1
                
                next_token_logits = next_token_logits * mask + (1 - mask) * (-1e9)
                
                # Sample next token
                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, 1)
                
                generated_text_tokens.append(next_token.item())
                
                # Append to sequence
                input_seq = torch.cat([input_seq, next_token.unsqueeze(0)], dim=1)
            
            # Decode text
            generated_text = text_tokenizer.text_decode(generated_text_tokens)
            
            samples.append((image_prompt, generated_text))
    
    return samples

def generate_unconditional_samples(model, text_tokenizer, vqvae, device, num_samples=9, max_length=58):
    """
    Generate unconditional samples (both text and images).
    
    Args:
        model: Trained iGPT model
        text_tokenizer: Text tokenizer
        vqvae: VQVAE model for decoding
        device: Device to run on
        num_samples: Number of samples to generate
        max_length: Maximum sequence length
        
    Returns:
        List of (image, text) tuples
    """
    model.eval()
    samples = []
    
    with torch.no_grad():
        for _ in range(num_samples):
            # Start with BOS token
            input_seq = torch.tensor([text_tokenizer.bos_token]).unsqueeze(0).to(device)
            
            # First, decide which modality to start with
            logits = model(input_seq)
            next_token_logits = logits[0, -1, :]
            
            # Only allow end_of_image or end_of_text tokens
            mask = torch.zeros_like(next_token_logits)
            mask[text_tokenizer.end_of_image_token] = 1
            mask[text_tokenizer.end_of_text_token] = 1
            next_token_logits = next_token_logits * mask + (1 - mask) * (-1e9)
            
            probs = torch.softmax(next_token_logits, dim=-1)
            modality_token = torch.multinomial(probs, 1)
            input_seq = torch.cat([input_seq, modality_token.unsqueeze(0)], dim=1)
            
            if modality_token.item() == text_tokenizer.end_of_image_token:
                # Generate text first, then image
                
                # Generate 6 text tokens
                for _ in range(6):
                    logits = model(input_seq)
                    next_token_logits = logits[0, -1, :]
                    
                    # Restrict to text tokens
                    mask = torch.zeros_like(next_token_logits)
                    for word, token_id in text_tokenizer.word_to_id.items():
                        if word not in ['<end_of_text>', '<end_of_image>']:
                            mask[token_id] = 1
                    
                    next_token_logits = next_token_logits * mask + (1 - mask) * (-1e9)
                    probs = torch.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, 1)
                    input_seq = torch.cat([input_seq, next_token.unsqueeze(0)], dim=1)
                
                # Add end_of_text token
                input_seq = torch.cat([input_seq, torch.tensor([text_tokenizer.end_of_text_token]).unsqueeze(0).to(device)], dim=1)
                
                # Generate 49 image tokens
                for _ in range(49):
                    logits = model(input_seq)
                    next_token_logits = logits[0, -1, :]
                    
                    # Restrict to image tokens
                    mask = torch.zeros_like(next_token_logits)
                    mask[:vqvae.n_embeddings] = 1
                    next_token_logits = next_token_logits * mask + (1 - mask) * (-1e9)
                    
                    probs = torch.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, 1)
                    input_seq = torch.cat([input_seq, next_token.unsqueeze(0)], dim=1)
                
                # Extract text and image
                text_tokens = input_seq[0, 2:8].cpu().numpy()  # Skip BOS, end_of_image, get 6 text tokens
                image_tokens = input_seq[0, -49:].cpu().numpy().reshape(7, 7)
                
            else:  # end_of_text_token
                # Generate image first, then text
                
                # Generate 49 image tokens
                for _ in range(49):
                    logits = model(input_seq)
                    next_token_logits = logits[0, -1, :]
                    
                    # Restrict to image tokens
                    mask = torch.zeros_like(next_token_logits)
                    mask[:vqvae.n_embeddings] = 1
                    next_token_logits = next_token_logits * mask + (1 - mask) * (-1e9)
                    
                    probs = torch.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, 1)
                    input_seq = torch.cat([input_seq, next_token.unsqueeze(0)], dim=1)
                
                # Add end_of_image token
                input_seq = torch.cat([input_seq, torch.tensor([text_tokenizer.end_of_image_token]).unsqueeze(0).to(device)], dim=1)
                
                # Generate 6 text tokens
                for _ in range(6):
                    logits = model(input_seq)
                    next_token_logits = logits[0, -1, :]
                    
                    # Restrict to text tokens
                    mask = torch.zeros_like(next_token_logits)
                    for word, token_id in text_tokenizer.word_to_id.items():
                        if word not in ['<end_of_text>', '<end_of_image>']:
                            mask[token_id] = 1
                    
                    next_token_logits = next_token_logits * mask + (1 - mask) * (-1e9)
                    probs = torch.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, 1)
                    input_seq = torch.cat([input_seq, next_token.unsqueeze(0)], dim=1)
                
                # Extract image and text
                image_tokens = input_seq[0, 2:51].cpu().numpy().reshape(7, 7)  # Skip BOS, end_of_text, get 49 image tokens
                text_tokens = input_seq[0, -6:].cpu().numpy()  # Get last 6 text tokens
            
            # Decode
            decoded_image = vqvae.decode(image_tokens.reshape(1, 7, 7))[0]
            decoded_text = text_tokenizer.text_decode(text_tokens)
            
            samples.append((decoded_image, decoded_text))
    
    return samples
In [34]:
def q6_a(train_data, test_data, image_shape, train_text, test_text, image_test_prompt, text_test_prompt, vqvae):
    """
    train_data: A (n_train, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
    test_data: A (n_test, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
    image_shape: tuple (H, W, C) The shape of the images in the dataset, indicating height, width, and number of color channels.
    train_text: list[str] Text data associated with each training image.
    test_text: list[str] Text data associated with each test image.
    image_test_prompt: (9, H, W, C) Image data used for generating conditional text samples during testing.
    text_test_prompt: list of 9 strings Text prompts used for generating conditional image samples during testing.
    vqvae: a vqvae model, trained on the relevant dataset

    Returns
    - a (# of training iterations,) numpy array of train_losses evaluated every minibatch
    - a (# of epochs + 1,) numpy array of test_losses evaluated once at initialization and after each epoch
    - a list of 9 (image, text), corresponding to the image conditioned samples
    - a list of 9 (image, text), corresponding to the text conditions samples
    - a list of 9 (image, text), corresponding to unconditional samples
    """
    # Fix the offset parameter for the tokenizer - it should be the vocab_size, not 0
    text_tokenizer = Tokenizer(train_text, vqvae.n_embeddings)
    
    H, W, C = image_shape
    batch_size = 128
    learning_rate = 1e-3
    num_epochs = 30
    d_model = 128
    n_heads = 4
    n_layers = 4
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # determine sequence length and vocab size
    sequence_length = 58 # 49 + 6 +2 + 1
    # Total vocab size should include both image tokens and text tokens
    total_vocab_size = vqvae.n_embeddings + len(text_tokenizer.all_words)
    # get subset of data to test first 
    train_loader = create_dataset(train_data, train_text, vqvae, text_tokenizer, batch_size)
    test_loader = create_dataset(test_data, test_text, vqvae, text_tokenizer, batch_size)
    
    model = iGPT(total_vocab_size, sequence_length, d_model, n_heads, n_layers).to(device)
    train_losses, test_losses = train_igpt(model, train_loader, test_loader, 
                                            sequence_length, total_vocab_size, device,
                                            num_epochs, learning_rate)
    

    # Generate samples
    samples_text_conditioned = generate_conditional_samples_from_text(
        model, text_tokenizer, vqvae, text_test_prompt, device
    )
    
    samples_image_conditioned = generate_conditional_samples_from_image(
        model, text_tokenizer, vqvae, image_test_prompt, device
    )
    
    samples_unconditioned = generate_unconditional_samples(
        model, text_tokenizer, vqvae, device, num_samples=9
    )
    return train_losses, test_losses, samples_image_conditioned, samples_text_conditioned, samples_unconditioned

Results¶

Once you've implemented q6_a, execute the cells below to visualize and save your results

In [35]:
q6a_save_results(q6_a)
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
data_dir:  /home/nghiaph/workspace/deepul/homeworks/hw1/data
Creating dataset from 60000 samples...
Pre-tokenizing all text data...
Batch processing images...
Processed 128/60000 images (2.1%)
Assembling dataset...
Dataset creation complete! Total samples: 60000
Creating DataLoader with batch_size=128
Creating dataset from 10000 samples...
Pre-tokenizing all text data...
Batch processing images...
Processed 128/10000 images (12.8%)
Assembling dataset...
Dataset creation complete! Total samples: 10000
Creating DataLoader with batch_size=128
Epoch 1/30, Batch 0/469, Loss: 7.0951
Epoch 1/30, Batch 50/469, Loss: 6.9365
Epoch 1/30, Batch 100/469, Loss: 6.4238
Epoch 1/30, Batch 150/469, Loss: 5.9999
Epoch 1/30, Batch 200/469, Loss: 5.5952
Epoch 1/30, Batch 250/469, Loss: 5.0091
Epoch 1/30, Batch 300/469, Loss: 4.3433
Epoch 1/30, Batch 350/469, Loss: 3.9322
Epoch 1/30, Batch 400/469, Loss: 3.8381
Epoch 1/30, Batch 450/469, Loss: 3.7479
Epoch 1/30 completed. Test Loss: 3.6811
Epoch 2/30, Batch 0/469, Loss: 3.7475
Epoch 2/30, Batch 50/469, Loss: 3.6743
Epoch 2/30, Batch 100/469, Loss: 3.6187
Epoch 2/30, Batch 150/469, Loss: 3.6176
Epoch 2/30, Batch 200/469, Loss: 3.5434
Epoch 2/30, Batch 250/469, Loss: 3.6104
Epoch 2/30, Batch 300/469, Loss: 3.5620
Epoch 2/30, Batch 350/469, Loss: 3.4631
Epoch 2/30, Batch 400/469, Loss: 3.4442
Epoch 2/30, Batch 450/469, Loss: 3.4916
Epoch 2/30 completed. Test Loss: 3.3927
Epoch 3/30, Batch 0/469, Loss: 3.4020
Epoch 3/30, Batch 50/469, Loss: 3.4850
Epoch 3/30, Batch 100/469, Loss: 3.3626
Epoch 3/30, Batch 150/469, Loss: 3.3485
Epoch 3/30, Batch 200/469, Loss: 3.3734
Epoch 3/30, Batch 250/469, Loss: 3.2823
Epoch 3/30, Batch 300/469, Loss: 3.2191
Epoch 3/30, Batch 350/469, Loss: 3.2746
Epoch 3/30, Batch 400/469, Loss: 3.3012
Epoch 3/30, Batch 450/469, Loss: 3.2492
Epoch 3/30 completed. Test Loss: 3.1725
Epoch 4/30, Batch 0/469, Loss: 3.1248
Epoch 4/30, Batch 50/469, Loss: 3.2201
Epoch 4/30, Batch 100/469, Loss: 3.2481
Epoch 4/30, Batch 150/469, Loss: 3.2488
Epoch 4/30, Batch 200/469, Loss: 3.1117
Epoch 4/30, Batch 250/469, Loss: 3.0753
Epoch 4/30, Batch 300/469, Loss: 3.1594
Epoch 4/30, Batch 350/469, Loss: 3.1258
Epoch 4/30, Batch 400/469, Loss: 3.0861
Epoch 4/30, Batch 450/469, Loss: 3.0951
Epoch 4/30 completed. Test Loss: 3.0241
Epoch 5/30, Batch 0/469, Loss: 3.0616
Epoch 5/30, Batch 50/469, Loss: 3.0368
Epoch 5/30, Batch 100/469, Loss: 3.0523
Epoch 5/30, Batch 150/469, Loss: 3.0173
Epoch 5/30, Batch 200/469, Loss: 2.9779
Epoch 5/30, Batch 250/469, Loss: 3.0774
Epoch 5/30, Batch 300/469, Loss: 3.0289
Epoch 5/30, Batch 350/469, Loss: 3.0879
Epoch 5/30, Batch 400/469, Loss: 2.9733
Epoch 5/30, Batch 450/469, Loss: 3.0253
Epoch 5/30 completed. Test Loss: 2.9287
Epoch 6/30, Batch 0/469, Loss: 2.9740
Epoch 6/30, Batch 50/469, Loss: 2.9356
Epoch 6/30, Batch 100/469, Loss: 3.0952
Epoch 6/30, Batch 150/469, Loss: 2.9799
Epoch 6/30, Batch 200/469, Loss: 2.9494
Epoch 6/30, Batch 250/469, Loss: 2.9638
Epoch 6/30, Batch 300/469, Loss: 2.9206
Epoch 6/30, Batch 350/469, Loss: 2.9036
Epoch 6/30, Batch 400/469, Loss: 3.0060
Epoch 6/30, Batch 450/469, Loss: 2.9352
Epoch 6/30 completed. Test Loss: 2.8623
Epoch 7/30, Batch 0/469, Loss: 2.9634
Epoch 7/30, Batch 50/469, Loss: 2.9687
Epoch 7/30, Batch 100/469, Loss: 2.8994
Epoch 7/30, Batch 150/469, Loss: 2.9184
Epoch 7/30, Batch 200/469, Loss: 2.8395
Epoch 7/30, Batch 250/469, Loss: 2.9313
Epoch 7/30, Batch 300/469, Loss: 2.8314
Epoch 7/30, Batch 350/469, Loss: 2.9302
Epoch 7/30, Batch 400/469, Loss: 2.8441
Epoch 7/30, Batch 450/469, Loss: 2.8525
Epoch 7/30 completed. Test Loss: 2.8188
Epoch 8/30, Batch 0/469, Loss: 2.8840
Epoch 8/30, Batch 50/469, Loss: 2.8966
Epoch 8/30, Batch 100/469, Loss: 2.8827
Epoch 8/30, Batch 150/469, Loss: 2.8918
Epoch 8/30, Batch 200/469, Loss: 2.8461
Epoch 8/30, Batch 250/469, Loss: 2.9240
Epoch 8/30, Batch 300/469, Loss: 2.8935
Epoch 8/30, Batch 350/469, Loss: 2.8889
Epoch 8/30, Batch 400/469, Loss: 2.7807
Epoch 8/30, Batch 450/469, Loss: 2.8786
Epoch 8/30 completed. Test Loss: 2.7761
Epoch 9/30, Batch 0/469, Loss: 2.8313
Epoch 9/30, Batch 50/469, Loss: 2.8219
Epoch 9/30, Batch 100/469, Loss: 2.8402
Epoch 9/30, Batch 150/469, Loss: 2.8810
Epoch 9/30, Batch 200/469, Loss: 2.7957
Epoch 9/30, Batch 250/469, Loss: 2.8398
Epoch 9/30, Batch 300/469, Loss: 2.9297
Epoch 9/30, Batch 350/469, Loss: 2.9504
Epoch 9/30, Batch 400/469, Loss: 2.8850
Epoch 9/30, Batch 450/469, Loss: 2.8889
Epoch 9/30 completed. Test Loss: 2.7530
Epoch 10/30, Batch 0/469, Loss: 2.7714
Epoch 10/30, Batch 50/469, Loss: 2.8695
Epoch 10/30, Batch 100/469, Loss: 2.7365
Epoch 10/30, Batch 150/469, Loss: 2.8018
Epoch 10/30, Batch 200/469, Loss: 2.7453
Epoch 10/30, Batch 250/469, Loss: 2.8007
Epoch 10/30, Batch 300/469, Loss: 2.8421
Epoch 10/30, Batch 350/469, Loss: 2.8254
Epoch 10/30, Batch 400/469, Loss: 2.7566
Epoch 10/30, Batch 450/469, Loss: 2.8254
Epoch 10/30 completed. Test Loss: 2.7228
Epoch 11/30, Batch 0/469, Loss: 2.7551
Epoch 11/30, Batch 50/469, Loss: 2.8213
Epoch 11/30, Batch 100/469, Loss: 2.8193
Epoch 11/30, Batch 150/469, Loss: 2.7120
Epoch 11/30, Batch 200/469, Loss: 2.7770
Epoch 11/30, Batch 250/469, Loss: 2.7634
Epoch 11/30, Batch 300/469, Loss: 2.7723
Epoch 11/30, Batch 350/469, Loss: 2.7986
Epoch 11/30, Batch 400/469, Loss: 2.6951
Epoch 11/30, Batch 450/469, Loss: 2.8092
Epoch 11/30 completed. Test Loss: 2.7056
Epoch 12/30, Batch 0/469, Loss: 2.6845
Epoch 12/30, Batch 50/469, Loss: 2.7566
Epoch 12/30, Batch 100/469, Loss: 2.7650
Epoch 12/30, Batch 150/469, Loss: 2.7342
Epoch 12/30, Batch 200/469, Loss: 2.7513
Epoch 12/30, Batch 250/469, Loss: 2.7280
Epoch 12/30, Batch 300/469, Loss: 2.8334
Epoch 12/30, Batch 350/469, Loss: 2.6839
Epoch 12/30, Batch 400/469, Loss: 2.7525
Epoch 12/30, Batch 450/469, Loss: 2.7641
Epoch 12/30 completed. Test Loss: 2.6857
Epoch 13/30, Batch 0/469, Loss: 2.7559
Epoch 13/30, Batch 50/469, Loss: 2.6989
Epoch 13/30, Batch 100/469, Loss: 2.7418
Epoch 13/30, Batch 150/469, Loss: 2.7183
Epoch 13/30, Batch 200/469, Loss: 2.7075
Epoch 13/30, Batch 250/469, Loss: 2.7152
Epoch 13/30, Batch 300/469, Loss: 2.6814
Epoch 13/30, Batch 350/469, Loss: 2.7818
Epoch 13/30, Batch 400/469, Loss: 2.6887
Epoch 13/30, Batch 450/469, Loss: 2.7272
Epoch 13/30 completed. Test Loss: 2.6783
Epoch 14/30, Batch 0/469, Loss: 2.7592
Epoch 14/30, Batch 50/469, Loss: 2.6698
Epoch 14/30, Batch 100/469, Loss: 2.6606
Epoch 14/30, Batch 150/469, Loss: 2.6669
Epoch 14/30, Batch 200/469, Loss: 2.7069
Epoch 14/30, Batch 250/469, Loss: 2.6557
Epoch 14/30, Batch 300/469, Loss: 2.7253
Epoch 14/30, Batch 350/469, Loss: 2.7223
Epoch 14/30, Batch 400/469, Loss: 2.6526
Epoch 14/30, Batch 450/469, Loss: 2.7666
Epoch 14/30 completed. Test Loss: 2.6631
Epoch 15/30, Batch 0/469, Loss: 2.8077
Epoch 15/30, Batch 50/469, Loss: 2.7296
Epoch 15/30, Batch 100/469, Loss: 2.7253
Epoch 15/30, Batch 150/469, Loss: 2.7593
Epoch 15/30, Batch 200/469, Loss: 2.6732
Epoch 15/30, Batch 250/469, Loss: 2.6974
Epoch 15/30, Batch 300/469, Loss: 2.7120
Epoch 15/30, Batch 350/469, Loss: 2.7051
Epoch 15/30, Batch 400/469, Loss: 2.7432
Epoch 15/30, Batch 450/469, Loss: 2.7349
Epoch 15/30 completed. Test Loss: 2.6495
Epoch 16/30, Batch 0/469, Loss: 2.7292
Epoch 16/30, Batch 50/469, Loss: 2.6720
Epoch 16/30, Batch 100/469, Loss: 2.6532
Epoch 16/30, Batch 150/469, Loss: 2.7154
Epoch 16/30, Batch 200/469, Loss: 2.7005
Epoch 16/30, Batch 250/469, Loss: 2.6644
Epoch 16/30, Batch 300/469, Loss: 2.6986
Epoch 16/30, Batch 350/469, Loss: 2.7245
Epoch 16/30, Batch 400/469, Loss: 2.6717
Epoch 16/30, Batch 450/469, Loss: 2.6643
Epoch 16/30 completed. Test Loss: 2.6388
Epoch 17/30, Batch 0/469, Loss: 2.6515
Epoch 17/30, Batch 50/469, Loss: 2.6336
Epoch 17/30, Batch 100/469, Loss: 2.6795
Epoch 17/30, Batch 150/469, Loss: 2.6871
Epoch 17/30, Batch 200/469, Loss: 2.7344
Epoch 17/30, Batch 250/469, Loss: 2.6723
Epoch 17/30, Batch 300/469, Loss: 2.7224
Epoch 17/30, Batch 350/469, Loss: 2.6828
Epoch 17/30, Batch 400/469, Loss: 2.7290
Epoch 17/30, Batch 450/469, Loss: 2.6904
Epoch 17/30 completed. Test Loss: 2.6326
Epoch 18/30, Batch 0/469, Loss: 2.7198
Epoch 18/30, Batch 50/469, Loss: 2.6155
Epoch 18/30, Batch 100/469, Loss: 2.6426
Epoch 18/30, Batch 150/469, Loss: 2.6718
Epoch 18/30, Batch 200/469, Loss: 2.6358
Epoch 18/30, Batch 250/469, Loss: 2.6954
Epoch 18/30, Batch 300/469, Loss: 2.7013
Epoch 18/30, Batch 350/469, Loss: 2.6637
Epoch 18/30, Batch 400/469, Loss: 2.6466
Epoch 18/30, Batch 450/469, Loss: 2.6998
Epoch 18/30 completed. Test Loss: 2.6201
Epoch 19/30, Batch 0/469, Loss: 2.6950
Epoch 19/30, Batch 50/469, Loss: 2.5979
Epoch 19/30, Batch 100/469, Loss: 2.6393
Epoch 19/30, Batch 150/469, Loss: 2.6538
Epoch 19/30, Batch 200/469, Loss: 2.7432
Epoch 19/30, Batch 250/469, Loss: 2.6518
Epoch 19/30, Batch 300/469, Loss: 2.6267
Epoch 19/30, Batch 350/469, Loss: 2.6606
Epoch 19/30, Batch 400/469, Loss: 2.5855
Epoch 19/30, Batch 450/469, Loss: 2.6239
Epoch 19/30 completed. Test Loss: 2.6183
Epoch 20/30, Batch 0/469, Loss: 2.6392
Epoch 20/30, Batch 50/469, Loss: 2.6116
Epoch 20/30, Batch 100/469, Loss: 2.6269
Epoch 20/30, Batch 150/469, Loss: 2.7098
Epoch 20/30, Batch 200/469, Loss: 2.6827
Epoch 20/30, Batch 250/469, Loss: 2.6657
Epoch 20/30, Batch 300/469, Loss: 2.6737
Epoch 20/30, Batch 350/469, Loss: 2.6417
Epoch 20/30, Batch 400/469, Loss: 2.6517
Epoch 20/30, Batch 450/469, Loss: 2.5823
Epoch 20/30 completed. Test Loss: 2.6076
Epoch 21/30, Batch 0/469, Loss: 2.6687
Epoch 21/30, Batch 50/469, Loss: 2.6462
Epoch 21/30, Batch 100/469, Loss: 2.6730
Epoch 21/30, Batch 150/469, Loss: 2.6893
Epoch 21/30, Batch 200/469, Loss: 2.6747
Epoch 21/30, Batch 250/469, Loss: 2.7593
Epoch 21/30, Batch 300/469, Loss: 2.6666
Epoch 21/30, Batch 350/469, Loss: 2.7037
Epoch 21/30, Batch 400/469, Loss: 2.6659
Epoch 21/30, Batch 450/469, Loss: 2.6344
Epoch 21/30 completed. Test Loss: 2.6016
Epoch 22/30, Batch 0/469, Loss: 2.5732
Epoch 22/30, Batch 50/469, Loss: 2.6851
Epoch 22/30, Batch 100/469, Loss: 2.6935
Epoch 22/30, Batch 150/469, Loss: 2.7070
Epoch 22/30, Batch 200/469, Loss: 2.6485
Epoch 22/30, Batch 250/469, Loss: 2.6468
Epoch 22/30, Batch 300/469, Loss: 2.5894
Epoch 22/30, Batch 350/469, Loss: 2.6464
Epoch 22/30, Batch 400/469, Loss: 2.6810
Epoch 22/30, Batch 450/469, Loss: 2.6639
Epoch 22/30 completed. Test Loss: 2.5993
Epoch 23/30, Batch 0/469, Loss: 2.5770
Epoch 23/30, Batch 50/469, Loss: 2.6658
Epoch 23/30, Batch 100/469, Loss: 2.6333
Epoch 23/30, Batch 150/469, Loss: 2.6432
Epoch 23/30, Batch 200/469, Loss: 2.7187
Epoch 23/30, Batch 250/469, Loss: 2.6751
Epoch 23/30, Batch 300/469, Loss: 2.6216
Epoch 23/30, Batch 350/469, Loss: 2.6480
Epoch 23/30, Batch 400/469, Loss: 2.5892
Epoch 23/30, Batch 450/469, Loss: 2.7111
Epoch 23/30 completed. Test Loss: 2.5942
Epoch 24/30, Batch 0/469, Loss: 2.5932
Epoch 24/30, Batch 50/469, Loss: 2.6080
Epoch 24/30, Batch 100/469, Loss: 2.5952
Epoch 24/30, Batch 150/469, Loss: 2.5781
Epoch 24/30, Batch 200/469, Loss: 2.6943
Epoch 24/30, Batch 250/469, Loss: 2.6286
Epoch 24/30, Batch 300/469, Loss: 2.6762
Epoch 24/30, Batch 350/469, Loss: 2.6617
Epoch 24/30, Batch 400/469, Loss: 2.6114
Epoch 24/30, Batch 450/469, Loss: 2.6298
Epoch 24/30 completed. Test Loss: 2.5928
Epoch 25/30, Batch 0/469, Loss: 2.6036
Epoch 25/30, Batch 50/469, Loss: 2.7137
Epoch 25/30, Batch 100/469, Loss: 2.6335
Epoch 25/30, Batch 150/469, Loss: 2.5694
Epoch 25/30, Batch 200/469, Loss: 2.6246
Epoch 25/30, Batch 250/469, Loss: 2.5550
Epoch 25/30, Batch 300/469, Loss: 2.5811
Epoch 25/30, Batch 350/469, Loss: 2.6790
Epoch 25/30, Batch 400/469, Loss: 2.6278
Epoch 25/30, Batch 450/469, Loss: 2.6187
Epoch 25/30 completed. Test Loss: 2.5892
Epoch 26/30, Batch 0/469, Loss: 2.6463
Epoch 26/30, Batch 50/469, Loss: 2.6439
Epoch 26/30, Batch 100/469, Loss: 2.6300
Epoch 26/30, Batch 150/469, Loss: 2.5776
Epoch 26/30, Batch 200/469, Loss: 2.6629
Epoch 26/30, Batch 250/469, Loss: 2.6139
Epoch 26/30, Batch 300/469, Loss: 2.6144
Epoch 26/30, Batch 350/469, Loss: 2.6701
Epoch 26/30, Batch 400/469, Loss: 2.5264
Epoch 26/30, Batch 450/469, Loss: 2.5714
Epoch 26/30 completed. Test Loss: 2.5870
Epoch 27/30, Batch 0/469, Loss: 2.6233
Epoch 27/30, Batch 50/469, Loss: 2.6410
Epoch 27/30, Batch 100/469, Loss: 2.6642
Epoch 27/30, Batch 150/469, Loss: 2.5993
Epoch 27/30, Batch 200/469, Loss: 2.6415
Epoch 27/30, Batch 250/469, Loss: 2.6294
Epoch 27/30, Batch 300/469, Loss: 2.5992
Epoch 27/30, Batch 350/469, Loss: 2.6654
Epoch 27/30, Batch 400/469, Loss: 2.4930
Epoch 27/30, Batch 450/469, Loss: 2.6011
Epoch 27/30 completed. Test Loss: 2.5856
Epoch 28/30, Batch 0/469, Loss: 2.6330
Epoch 28/30, Batch 50/469, Loss: 2.6360
Epoch 28/30, Batch 100/469, Loss: 2.6049
Epoch 28/30, Batch 150/469, Loss: 2.6217
Epoch 28/30, Batch 200/469, Loss: 2.5843
Epoch 28/30, Batch 250/469, Loss: 2.6726
Epoch 28/30, Batch 300/469, Loss: 2.6267
Epoch 28/30, Batch 350/469, Loss: 2.5920
Epoch 28/30, Batch 400/469, Loss: 2.6712
Epoch 28/30, Batch 450/469, Loss: 2.6426
Epoch 28/30 completed. Test Loss: 2.5855
Epoch 29/30, Batch 0/469, Loss: 2.6381
Epoch 29/30, Batch 50/469, Loss: 2.6717
Epoch 29/30, Batch 100/469, Loss: 2.5929
Epoch 29/30, Batch 150/469, Loss: 2.6720
Epoch 29/30, Batch 200/469, Loss: 2.5194
Epoch 29/30, Batch 250/469, Loss: 2.6219
Epoch 29/30, Batch 300/469, Loss: 2.5707
Epoch 29/30, Batch 350/469, Loss: 2.5919
Epoch 29/30, Batch 400/469, Loss: 2.6117
Epoch 29/30, Batch 450/469, Loss: 2.5315
Epoch 29/30 completed. Test Loss: 2.5852
Epoch 30/30, Batch 0/469, Loss: 2.5599
Epoch 30/30, Batch 50/469, Loss: 2.6249
Epoch 30/30, Batch 100/469, Loss: 2.5602
Epoch 30/30, Batch 150/469, Loss: 2.5848
Epoch 30/30, Batch 200/469, Loss: 2.6370
Epoch 30/30, Batch 250/469, Loss: 2.5630
Epoch 30/30, Batch 300/469, Loss: 2.6131
Epoch 30/30, Batch 350/469, Loss: 2.6885
Epoch 30/30, Batch 400/469, Loss: 2.6537
Epoch 30/30, Batch 450/469, Loss: 2.6980
Epoch 30/30 completed. Test Loss: 2.5852
/tmp/ipykernel_2898721/2746847972.py:82: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  torch.tensor(image_tokens),
Final Test Loss: 2.5852
No description has been provided for this image
No description has been provided for this image
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-10..262].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-11..177].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-9..266].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-39..306].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-10..259].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-10..266].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-9..266].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [78..266].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [37..275].
No description has been provided for this image
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-27..319].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-3..224].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-48..294].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-9..262].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-6..286].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [75..266].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-23..286].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [38..268].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [18..259].
No description has been provided for this image